From 043fb4337f7ece36709957c239f2479e9e45f33c Mon Sep 17 00:00:00 2001 From: Dario Mylonopoulos Date: Fri, 1 Dec 2023 15:31:11 +0100 Subject: [PATCH] Added GaussianSplats renderable, GPU sort and gaussian_splatting.py example --- aitviewer/renderables/gaussian_splats.py | 235 ++++++ aitviewer/shaders.py | 26 + .../shaders/gaussian_splatting/common.glsl | 24 + .../shaders/gaussian_splatting/draw.glsl | 63 ++ .../shaders/gaussian_splatting/prepare.glsl | 172 ++++ .../shaders/gaussian_splatting/sort.glsl | 743 ++++++++++++++++++ aitviewer/utils/gpu_sort.py | 161 ++++ examples/gaussian_splatting.py | 104 +++ 8 files changed, 1528 insertions(+) create mode 100644 aitviewer/renderables/gaussian_splats.py create mode 100644 aitviewer/shaders/gaussian_splatting/common.glsl create mode 100644 aitviewer/shaders/gaussian_splatting/draw.glsl create mode 100644 aitviewer/shaders/gaussian_splatting/prepare.glsl create mode 100644 aitviewer/shaders/gaussian_splatting/sort.glsl create mode 100644 aitviewer/utils/gpu_sort.py create mode 100644 examples/gaussian_splatting.py diff --git a/aitviewer/renderables/gaussian_splats.py b/aitviewer/renderables/gaussian_splats.py new file mode 100644 index 0000000..5e2bfe7 --- /dev/null +++ b/aitviewer/renderables/gaussian_splats.py @@ -0,0 +1,235 @@ +import moderngl +import numpy as np +from moderngl_window.opengl.vao import VAO + +from aitviewer.scene.camera import CameraInterface +from aitviewer.scene.node import Node +from aitviewer.shaders import ( + get_gaussian_splat_draw_program, + get_gaussian_splat_prepare_program, +) +from aitviewer.utils.decorators import hooked +from aitviewer.utils.gpu_sort import GpuSort + + +class GaussianSplats(Node): + PREPARE_GROUP_SIZE = 128 + + def __init__(self, splat_positions, splat_shs, splat_opacities, splat_scales, splat_rotations, **kwargs): + super().__init__(**kwargs) + + self.num_splats = splat_positions.shape[0] + self.splat_positions: np.ndarray = splat_positions + self.splat_shs: np.ndarray = splat_shs + self.splat_opacities: np.ndarray = splat_opacities + self.splat_scales: np.ndarray = splat_scales + self.splat_rotations: np.ndarray = splat_rotations + + self.splat_opacity_scale = 1.0 + self.splat_size_scale = 1.0 + + self.backface_culling = False + + self._debug_gui = False + + def is_transparent(self): + return True + + @property + def bounds(self): + return self.current_bounds + + @property + def current_bounds(self): + return self.get_bounds(self.splat_positions) + + @Node.once + def make_renderable(self, ctx: moderngl.Context): + self.prog_prepare = get_gaussian_splat_prepare_program(self.PREPARE_GROUP_SIZE) + self.prog_draw = get_gaussian_splat_draw_program() + + # Buffer for splat positions. + self.splat_positions_buf = ctx.buffer(self.splat_positions.astype(np.float32).tobytes()) + + # Buffer for other splat data: SHs, opacity, scale, rotation. + # + # TODO: In theory we could pre-process rotations and scales and store them + # as a 6 element covariance matrix directly here. + # + # TODO: Currently this only renders with base colors (first spherical harmonics coefficient) + # We need to unswizzle the other coefficients and evaluate them for rendering here. + splat_data = np.hstack( + ( + self.splat_shs[:, :3] * 0.2820948 + 0.5, + self.splat_opacities[:, np.newaxis], + self.splat_scales, + np.zeros((self.num_splats, 1), np.float32), + self.splat_rotations, + ) + ) + self.splat_data_buf = ctx.buffer(splat_data.astype(np.float32).tobytes()) + + # Buffer for splat views. 
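+        # Each SplatView packs a vec4 clip-space position, two vec2 screen-space axes and a
+        # vec4 color (see common.glsl): 12 floats = 48 bytes per splat with std430 layout.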
+ self.splat_views_buf = ctx.buffer(None, reserve=self.num_splats * 48) + + # Buffers for distances and sorted indices. + self.splat_distances_buf = ctx.buffer(None, reserve=self.num_splats * 4) + self.splat_sorted_indices_buf = ctx.buffer(None, reserve=self.num_splats * 4) + + # Create a vao for rendering a single quad. + indices = np.array((0, 1, 2, 1, 3, 2), np.uint32) + self.vbo_indices = ctx.buffer(indices.tobytes()) + self.vao = VAO() + self.vao.index_buffer(self.vbo_indices) + + self.gpu_sort = GpuSort(ctx, self.num_splats) + + # Time queries for profiling. + self.time_queries = { + "prepare": ctx.query(time=True), + "sort": ctx.query(time=True), + "draw": ctx.query(time=True), + } + self.ctx = ctx + + def render(self, camera: CameraInterface, **kwargs): + # Convert gaussians from 3D to 2D quads. + with self.time_queries["prepare"]: + self.splat_positions_buf.bind_to_storage_buffer(0) + self.splat_data_buf.bind_to_storage_buffer(1) + self.splat_views_buf.bind_to_storage_buffer(2) + self.splat_distances_buf.bind_to_storage_buffer(3) + self.splat_sorted_indices_buf.bind_to_storage_buffer(4) + + self.prog_prepare["u_opacity_scale"] = self.splat_opacity_scale + self.prog_prepare["u_scale2"] = np.square(self.splat_size_scale) + + self.prog_prepare["u_num_splats"] = self.num_splats + V = camera.get_view_matrix() + P = camera.get_projection_matrix() + self.prog_prepare["u_limit"] = 1.3 / P[0, 0] + self.prog_prepare["u_focal"] = kwargs["window_size"][0] * P[0, 0] * 0.5 + + self.prog_prepare["u_world_from_object"].write(self.model_matrix.T.astype("f4").tobytes()) + self.prog_prepare["u_view_from_world"].write(V.T.astype("f4").tobytes()) + self.prog_prepare["u_clip_from_world"].write((P @ V).astype("f4").T.tobytes()) + + num_groups = (self.num_splats + self.PREPARE_GROUP_SIZE - 1) // self.PREPARE_GROUP_SIZE + self.prog_prepare.run(num_groups, 1, 1) + self.ctx.memory_barrier() + + # Sort splats based on distance to camera plane. + with self.time_queries["sort"]: + self.gpu_sort.run(self.ctx, self.splat_distances_buf, self.splat_sorted_indices_buf) + + # Render each splat as a 2D quad with instancing. + with self.time_queries["draw"]: + self.splat_views_buf.bind_to_storage_buffer(2) + self.splat_sorted_indices_buf.bind_to_storage_buffer(3) + + self.prog_draw["u_screen_size"].write(np.array(kwargs["window_size"], np.float32).tobytes()) + + kwargs["fbo"].depth_mask = False + self.vao.render(self.prog_draw, moderngl.TRIANGLES, 6, 0, self.num_splats) + kwargs["fbo"].depth_mask = True + + def gui(self, imgui): + if self._debug_gui: + # Draw debugging stats about marching cubes and rendering time. + total = 0 + for k, v in self.time_queries.items(): + imgui.text(f"{k}: {v.elapsed * 1e-6:5.3f}ms") + total += v.elapsed * 1e-6 + imgui.text(f"Total: {total: 5.3f}ms") + + _, self.splat_size_scale = imgui.drag_float( + "Size", + self.splat_size_scale, + 1e-2, + min_value=0.001, + max_value=10.0, + format="%.3f", + ) + + _, self.splat_opacity_scale = imgui.drag_float( + "Opacity", + self.splat_opacity_scale, + 1e-2, + min_value=0.001, + max_value=10.0, + format="%.3f", + ) + + super().gui(imgui) + + @classmethod + def from_ply(cls, path, sh_degree=3, **kwargs): + with open(path, "rb") as f: + # Read header. 
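+            # The header of a point_cloud.ply written by the reference 3DGS implementation
+            # typically looks like:
+            #   ply
+            #   format binary_little_endian 1.0
+            #   element vertex <count>
+            #   property float x ... (positions, normals, SH coefficients, opacity, scales, rotation)
+            #   end_header
+            # Only the magic, encoding and vertex count are checked below; the property list is skipped.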
+ head = f.readline().decode("utf-8").strip().lower() + if head != "ply": + print(head) + raise ValueError(f"Not a ply file: {head}") + + encoding = f.readline().decode("utf-8").strip().lower() + if "binary_little_endian" not in encoding: + raise ValueError(f"Invalid encoding: {encoding}") + + elements = f.readline().decode("utf-8").strip().lower() + count = int(elements.split()[2]) + + # Read until end of header. + while f.readline().decode("utf-8").strip().lower() != "end_header": + pass + + # Number of 32 bit floats used to encode Spherical Harmonics coefficients. + # The last multiplication by 3 is because we have 3 components (RGB) for each coefficient. + sh_coeffs = (sh_degree + 1) * (sh_degree + 1) * 3 + + # Position (vec3), normal (vec3), spherical harmonics (sh_coeffs), opacity (float), + # scale (vec3) and rotation (quaternion). All values are float32 (4 bytes). + size = count * (3 + 3 + sh_coeffs + 1 + 3 + 4) * 4 + + data = f.read(size) + arr = np.frombuffer(data, dtype=np.float32).reshape((count, -1)) + + # Positions. + position = arr[:, :3].copy() + + # Currently we don't need normals for rendering. + # normal = arr[:, 3:6].copy() + + # Spherical harmonic coefficients. + sh = arr[:, 6 : 6 + sh_coeffs].copy() + + # Activate alpha: sigmoid(alpha). + opacity = 1.0 / (1.0 + np.exp(-arr[:, 6 + sh_coeffs])) + + # Exponentiate scale. + scale = np.exp(arr[:, 7 + sh_coeffs : 10 + sh_coeffs]) + + # Normalize quaternions. + rotation = arr[:, 10 + sh_coeffs : 14 + sh_coeffs].copy() + rotation /= np.linalg.norm(rotation, ord=2, axis=1)[..., np.newaxis] + + # Convert from wxyz to xyzw. + rotation = np.roll(rotation, -1, axis=1) + + return cls(position, sh, opacity, scale, rotation, **kwargs) + + @hooked + def release(self): + # vao and vbos are released by Meshes release method. + if self.is_renderable: + self.splat_positions_buf.release() + self.splat_data_buf.release() + self.splat_views_buf.release() + self.splat_distances_buf.release() + self.splat_sorted_indices_buf.release() + self.vao.release() + self.gpu_sort.release() + self.splat_positions = None + self.splat_shs = None + self.splat_opacities = None + self.splat_scales = None + self.splat_rotations = None diff --git a/aitviewer/shaders.py b/aitviewer/shaders.py index 74ce440..304a745 100644 --- a/aitviewer/shaders.py +++ b/aitviewer/shaders.py @@ -129,6 +129,29 @@ def get_marching_cubes_shader(name, BX, BY, BZ, COMPACT_GROUP_SIZE) -> moderngl. 
return resources.programs.load(ProgramDescription(compute_shader=path, defines=defines)) +@functools.lru_cache() +def get_sort_program(name): + defines = { + "ENTRY_PARALLEL_SORT_" + name: 1, + } + path = os.path.join("gaussian_splatting", "sort.glsl") + return resources.programs.load(ProgramDescription(compute_shader=path, defines=defines)) + + +@functools.lru_cache() +def get_gaussian_splat_prepare_program(PREPARE_GROUP_SIZE): + defines = { + "PREPARE_GROUP_SIZE": PREPARE_GROUP_SIZE, + } + path = os.path.join("gaussian_splatting", "prepare.glsl") + return resources.programs.load(ProgramDescription(compute_shader=path, defines=defines)) + + +@functools.lru_cache() +def get_gaussian_splat_draw_program(): + return load_program(os.path.join("gaussian_splatting", "draw.glsl")) + + def clear_shader_cache(): """Clear all cached shaders.""" funcs = [ @@ -141,6 +164,9 @@ def clear_shader_cache(): get_screen_texture_program, get_chessboard_program, get_marching_cubes_shader, + get_gaussian_splat_draw_program, + get_gaussian_splat_prepare_program, + get_sort_program, ] for f in funcs: f.cache_clear() diff --git a/aitviewer/shaders/gaussian_splatting/common.glsl b/aitviewer/shaders/gaussian_splatting/common.glsl new file mode 100644 index 0000000..4d872df --- /dev/null +++ b/aitviewer/shaders/gaussian_splatting/common.glsl @@ -0,0 +1,24 @@ +// Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos + +struct Splat { + vec3 position; + vec3 color; + float opacity; + vec3 scale; + vec4 rotation; +}; + +struct SplatData { + vec3 color; + float opacity; + vec3 scale; + float _padding; + vec4 rotation; +}; + +struct SplatView { + vec4 position; + vec2 axis1; + vec2 axis2; + vec4 color; +}; diff --git a/aitviewer/shaders/gaussian_splatting/draw.glsl b/aitviewer/shaders/gaussian_splatting/draw.glsl new file mode 100644 index 0000000..db81e2d --- /dev/null +++ b/aitviewer/shaders/gaussian_splatting/draw.glsl @@ -0,0 +1,63 @@ +#version 450 + +// Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos + +#include gaussian_splatting/common.glsl + +#if defined VERTEX_SHADER + +layout(std430, binding=2) buffer in_splat_views +{ + SplatView views[]; +} InSplatViews; + +layout(std430, binding=3) buffer in_splat_sorted_indices +{ + uint indices[]; +} InSplatSortedIndices; + +uniform vec2 u_screen_size; + +out vec2 position; +out vec4 color; + +void main() { + uint splat_index = gl_InstanceID; + uint vertex_index = gl_VertexID; + + uint sorted_index = InSplatSortedIndices.indices[splat_index]; + SplatView view = InSplatViews.views[sorted_index]; + if(view.position.w <= 0.0) { + gl_Position.x = uintBitsToFloat(0x7fc00000); // NaN discards the primitive + } else { + vec2 quad = vec2(vertex_index & 1, (vertex_index >> 1) & 1) * 2.0 - 1.0; + quad *= 2.0; + + vec2 delta = (quad.x * view.axis1 + quad.y * view.axis2) * 2.0 / u_screen_size; + vec4 p = view.position; + p.xy += delta * view.position.w; + + color = view.color; + position = quad; + gl_Position = p; + } +} + +#elif defined FRAGMENT_SHADER + +in vec2 position; +in vec4 color; + +out vec4 out_color; + +void main() { + float alpha = clamp(exp(-dot(position, position)) * color.a, 0.0, 1.0); + + if(alpha < 1.0/255.0) { + discard; + } + out_color = vec4(color.rgb, alpha); +} + + +#endif \ No newline at end of file diff --git a/aitviewer/shaders/gaussian_splatting/prepare.glsl b/aitviewer/shaders/gaussian_splatting/prepare.glsl new file mode 100644 index 0000000..6f320ad --- /dev/null +++ 
b/aitviewer/shaders/gaussian_splatting/prepare.glsl @@ -0,0 +1,172 @@ +#version 450 + +// Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos + +#include gaussian_splatting/common.glsl + +#define PREPARE_GROUP_SIZE 128 + +layout (local_size_x = PREPARE_GROUP_SIZE, local_size_y = 1, local_size_z = 1) in; + +uniform float u_opacity_scale; +uniform float u_scale2; + +uniform uint u_num_splats; +uniform float u_limit; +uniform float u_focal; + +uniform mat4 u_world_from_object; +uniform mat4 u_view_from_world; +uniform mat4 u_clip_from_world; + +layout(std430, binding=0) buffer in_splat_positions +{ + float positions[]; +} InSplatPositions; + +layout(std430, binding=1) buffer in_splat_data +{ + SplatData data[]; +} InSplatData; + +layout(std430, binding=2) buffer out_splat_views +{ + SplatView views[]; +} OutSplatViews; + +layout(std430, binding=3) buffer out_splat_distances +{ + uint distances[]; +} OutSplatDistances; + +layout(std430, binding=4) buffer out_splat_indices +{ + uint indices[]; +} OutSplatIndices; + +uint floatToSortableUint(float v) { + uint fu = floatBitsToUint(v); + uint mask = -(int(fu >> 31)) | 0x80000000; + return fu ^ mask; +} + +// from "EWA Splatting" (Zwicker et al 2002) eq. 31 +vec3 covariance2D(vec3 world_pos, vec3 cov3d0, vec3 cov3d1) +{ + vec3 view_pos = (u_view_from_world * vec4(world_pos, 1)).xyz; + + view_pos.x = clamp(view_pos.x / view_pos.z, -u_limit, u_limit) * view_pos.z; + view_pos.y = clamp(view_pos.y / view_pos.z, -u_limit, u_limit) * view_pos.z; + + mat3 J = transpose(mat3( + u_focal / view_pos.z, 0, -(u_focal * view_pos.x) / (view_pos.z * view_pos.z), + 0, u_focal / view_pos.z, -(u_focal * view_pos.y) / (view_pos.z * view_pos.z), + 0, 0, 0 + )); + + mat3 T = J * mat3(u_view_from_world); + + mat3 V = mat3( + cov3d0.x, cov3d0.y, cov3d0.z, + cov3d0.y, cov3d1.x, cov3d1.y, + cov3d0.z, cov3d1.y, cov3d1.z + ); + mat3 cov = T * V * transpose(T); + + // Low pass filter to make each splat at least 1px size. + cov[0][0] += 0.3; + cov[1][1] += 0.3; + + return vec3(cov[0][0], cov[1][0], cov[1][1]); +} + +void axisFromCovariance2D(vec3 cov2d, out vec2 v1, out vec2 v2) { + float diag1 = cov2d.x, diag2 = cov2d.z, offDiag = cov2d.y; + + float mid = 0.5 * (diag1 + diag2); + float radius = length(vec2((diag1 - diag2) / 2.0, offDiag)); + float lambda1 = mid + radius; + float lambda2 = max(mid - radius, 0.1); + vec2 diagVec = normalize(vec2(offDiag, lambda1 - diag1)); + float maxSize = 4096.0; + + v1 = min(sqrt(2.0 * lambda1), maxSize) * diagVec; + v2 = min(sqrt(2.0 * lambda2), maxSize) * vec2(diagVec.y, -diagVec.x); +} + +Splat loadSplat(uint index) { + Splat splat; + splat.position.x = InSplatPositions.positions[index * 3 + 0]; + splat.position.y = InSplatPositions.positions[index * 3 + 1]; + splat.position.z = InSplatPositions.positions[index * 3 + 2]; + + SplatData data = InSplatData.data[index]; + splat.color = data.color; + splat.opacity = data.opacity; + splat.scale = data.scale; + splat.rotation = data.rotation; + + return splat; +} + +mat3 matrixFromQuaternionScale(vec4 q, vec3 s) { + mat3 ms = mat3( + s.x, 0, 0, + 0, s.y, 0, + 0, 0, s.z + ); + + float x = q.x; + float y = q.y; + float z = q.z; + float w = q.w; + mat3 mr = transpose(mat3( + 1-2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y), + 2*(x*y + w*z), 1-2*(x*x + z*z), 2*(y*z - w*x), + 2*(x*z - w*y), 2*(y*z + w*x), 1-2*(x*x + y*y) + )); + + return mr * ms; +} + +void main() { + uint thread_idx = gl_GlobalInvocationID.x; + + // Check if block valid. 
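+    // Threads past the splat count exit early: the host dispatch rounds the group count up
+    // to a multiple of PREPARE_GROUP_SIZE.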
+ if(thread_idx >= u_num_splats) { + return; + } + + Splat splat = loadSplat(thread_idx); + + vec3 world_pos = (u_world_from_object * vec4(splat.position, 1.0)).xyz; + vec3 view_pos = (u_view_from_world * vec4(world_pos, 1.0)).xyz; + vec4 clip_pos = u_clip_from_world * vec4(world_pos, 1.0); + + mat3 rotation = mat3(u_world_from_object) * matrixFromQuaternionScale(splat.rotation, splat.scale); + + vec2 v1 = vec2(0.0); + vec2 v2 = vec2(0.0); + if(clip_pos.w > 0) { + mat3 cov_matrix = rotation * transpose(rotation); + vec3 cov3d0 = vec3(cov_matrix[0][0], cov_matrix[0][1], cov_matrix[0][2]) * u_scale2; + vec3 cov3d1 = vec3(cov_matrix[1][1], cov_matrix[1][2], cov_matrix[2][2]) * u_scale2; + + vec3 cov2d = covariance2D(world_pos, cov3d0, cov3d1); + axisFromCovariance2D(cov2d, v1, v2); + + // vec3 world_view_dir = u_camera_pos - world_pos; + // vec3 object_view_diw = u_object_from_world * world_view_dir; + // TODO: SH + } + + SplatView view; + view.position = clip_pos; + view.axis1 = v1; + view.axis2 = v2; + view.color = vec4(splat.color, splat.opacity * u_opacity_scale); + + OutSplatViews.views[thread_idx] = view; + OutSplatDistances.distances[thread_idx] = floatToSortableUint(view_pos.z); + OutSplatIndices.indices[thread_idx] = thread_idx; +} \ No newline at end of file diff --git a/aitviewer/shaders/gaussian_splatting/sort.glsl b/aitviewer/shaders/gaussian_splatting/sort.glsl new file mode 100644 index 0000000..b64ae65 --- /dev/null +++ b/aitviewer/shaders/gaussian_splatting/sort.glsl @@ -0,0 +1,743 @@ +#version 450 +#extension GL_KHR_shader_subgroup_quad : require +#extension GL_ARB_shader_ballot : require +#extension GL_KHR_shader_subgroup_arithmetic : require + +// Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos + +// Adapted from parallelsort algorithm in FidelityFX-SDK +// https://github.com/GPUOpen-LibrariesAndSDKs/FidelityFX-SDK/tree/main +// +// FidelityFX-SDK License +// +// Copyright (C) 2023 Advanced Micro Devices, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files(the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and /or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions : +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
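+
+// Overview: keys are sorted in FFX_PARALLELSORT_SORT_BITS_PER_PASS (4) bit digits, i.e. 8 passes
+// for 32-bit keys. Each pass runs count -> scan_reduce -> scan -> scan_add -> scatter; the host
+// (aitviewer/utils/gpu_sort.py) selects the entry point via the ENTRY_PARALLEL_SORT_* defines and
+// ping-pongs the key/payload buffers between passes.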
+ +// Entry point defines +#define ENTRY_PARALLEL_SORT_COUNT 0 +#define ENTRY_PARALLEL_SORT_SCAN_REDUCE 0 +#define ENTRY_PARALLEL_SORT_SCAN 0 +#define ENTRY_PARALLEL_SORT_SCAN_ADD 0 +#define ENTRY_PARALLEL_SORT_SCATTER 0 + +// Config defines +#define FFX_PARALLELSORT_COPY_VALUE + +// BINDINGS DEFINES +#define FFX_PARALLELSORT_BIND_UAV_SOURCE_KEYS 0 +#define FFX_PARALLELSORT_BIND_UAV_DEST_KEYS 1 +#define FFX_PARALLELSORT_BIND_UAV_SOURCE_PAYLOADS 2 +#define FFX_PARALLELSORT_BIND_UAV_DEST_PAYLOADS 3 +#define FFX_PARALLELSORT_BIND_UAV_SUM_TABLE 4 +#define FFX_PARALLELSORT_BIND_UAV_REDUCE_TABLE 5 +#define FFX_PARALLELSORT_BIND_UAV_SCAN_SOURCE 6 +#define FFX_PARALLELSORT_BIND_UAV_SCAN_DEST 7 +#define FFX_PARALLELSORT_BIND_UAV_SCAN_SCRATCH 8 + +// --- GLSL Defines +#define FFX_GLSL +#define FFX_GROUPSHARED shared +#define FFX_GROUP_MEMORY_BARRIER() groupMemoryBarrier(); barrier() + +#define FfxUInt32 uint +#define FfxInt32 int + +#define FFX_ATOMIC_ADD(x, y) atomicAdd(x, y) + +// --- Uniform buffer +layout(binding = 0, std140) uniform cbParallelSort_t +{ + FfxUInt32 numKeys; + FfxUInt32 numBlocksPerThreadGroup; + FfxUInt32 numThreadGroups; + FfxUInt32 numThreadGroupsWithAdditionalBlocks; + FfxUInt32 numReduceThreadgroupPerBin; + FfxUInt32 numScanValues; + FfxUInt32 shiftBit; + FfxUInt32 padding; +} u_constants; + +uint FfxNumKeys() { return u_constants.numKeys; } +uint FfxNumBlocksPerThreadGroup() { return u_constants.numBlocksPerThreadGroup; } +uint FfxNumThreadGroups() { return u_constants.numThreadGroups; } +uint FfxNumThreadGroupsWithAdditionalBlocks() { return u_constants.numThreadGroupsWithAdditionalBlocks; } +uint FfxNumReduceThreadgroupPerBin() { return u_constants.numReduceThreadgroupPerBin; } +uint FfxNumScanValues() { return u_constants.numScanValues; } +uint FfxShiftBit() { return u_constants.shiftBit; } + +// --- Buffers +layout(binding = FFX_PARALLELSORT_BIND_UAV_SOURCE_KEYS, std430) coherent buffer ParallelSortSrcKeys_t { uint source_keys[]; } rw_source_keys; +layout(binding = FFX_PARALLELSORT_BIND_UAV_DEST_KEYS, std430) coherent buffer ParallelSortDstKeys_t { uint dest_keys[]; } rw_dest_keys; +layout(binding = FFX_PARALLELSORT_BIND_UAV_SOURCE_PAYLOADS, std430) coherent buffer ParallelSortSrcPayload_t { uint source_payloads[]; } rw_source_payloads; +layout(binding = FFX_PARALLELSORT_BIND_UAV_DEST_PAYLOADS, std430) coherent buffer ParallelSortDstPayload_t { uint dest_payloads[]; } rw_dest_payloads; +layout(binding = FFX_PARALLELSORT_BIND_UAV_SUM_TABLE, std430) coherent buffer ParallelSortSumTable_t { uint sum_table[]; } rw_sum_table; +layout(binding = FFX_PARALLELSORT_BIND_UAV_REDUCE_TABLE, std430) coherent buffer ParallelSortReduceTable_t { uint reduce_table[]; } rw_reduce_table; +layout(binding = FFX_PARALLELSORT_BIND_UAV_SCAN_SOURCE, std430) coherent buffer ParallelSortScanSrc_t { uint scan_source[]; } rw_scan_source; +layout(binding = FFX_PARALLELSORT_BIND_UAV_SCAN_DEST, std430) coherent buffer ParallelSortScanDst_t { uint scan_dest[]; } rw_scan_dest; +layout(binding = FFX_PARALLELSORT_BIND_UAV_SCAN_SCRATCH, std430) coherent buffer ParallelSortScanScratch_t { uint scan_scratch[]; } rw_scan_scratch; + +FfxUInt32 LoadSourceKey(FfxUInt32 index) +{ + return rw_source_keys.source_keys[index]; +} + +void StoreDestKey(FfxUInt32 index, FfxUInt32 value) +{ + rw_dest_keys.dest_keys[index] = value; +} + +FfxUInt32 LoadSourcePayload(FfxUInt32 index) +{ + return rw_source_payloads.source_payloads[index]; +} + +void StoreDestPayload(FfxUInt32 index, FfxUInt32 value) +{ + 
rw_dest_payloads.dest_payloads[index] = value; +} + +FfxUInt32 LoadSumTable(FfxUInt32 index) +{ + return rw_sum_table.sum_table[index]; +} + +void StoreSumTable(FfxUInt32 index, FfxUInt32 value) +{ + rw_sum_table.sum_table[index] = value; +} + +void StoreReduceTable(FfxUInt32 index, FfxUInt32 value) +{ + rw_reduce_table.reduce_table[index] = value; +} + +FfxUInt32 LoadScanSource(FfxUInt32 index) +{ + return rw_scan_source.scan_source[index]; +} + +void StoreScanDest(FfxUInt32 index, FfxUInt32 value) +{ + rw_scan_dest.scan_dest[index] = value; +} + +FfxUInt32 LoadScanScratch(FfxUInt32 index) +{ + return rw_scan_scratch.scan_scratch[index]; +} + +FfxUInt32 FfxLoadKey(FfxUInt32 index) +{ + return LoadSourceKey(index); +} + +void FfxStoreKey(FfxUInt32 index, FfxUInt32 value) +{ + StoreDestKey(index, value); +} + +FfxUInt32 FfxLoadPayload(FfxUInt32 index) +{ + return LoadSourcePayload(index); +} + +void FfxStorePayload(FfxUInt32 index, FfxUInt32 value) +{ + StoreDestPayload(index, value); +} + +FfxUInt32 FfxLoadSum(FfxUInt32 index) +{ + return LoadSumTable(index); +} + +void FfxStoreSum(FfxUInt32 index, FfxUInt32 value) +{ + StoreSumTable(index, value); +} + +void FfxStoreReduce(FfxUInt32 index, FfxUInt32 value) +{ + StoreReduceTable(index, value); +} + +FfxUInt32 FfxLoadScanSource(FfxUInt32 index) +{ + return LoadScanSource(index); +} + +void FfxStoreScanDest(FfxUInt32 index, FfxUInt32 value) +{ + StoreScanDest(index, value); +} + +FfxUInt32 FfxLoadScanScratch(FfxUInt32 index) +{ + return LoadScanScratch(index); +} + + +// --- Implementation + +/// @defgroup FfxGPUParallelSort FidelityFX Parallel Sort +/// FidelityFX Parallel Sort GPU documentation +/// +/// @ingroup FfxGPUEffects + +/// The number of bits we are sorting per pass. +/// Changing this value requires +/// internal changes in LDS distribution and count, +/// reduce, scan, and scatter passes +/// +/// @ingroup FfxGPUParallelSort +#define FFX_PARALLELSORT_SORT_BITS_PER_PASS 4 + +/// The number of bins used for the counting phase +/// of the algorithm. Changing this value requires +/// internal changes in LDS distribution and count, +/// reduce, scan, and scatter passes +/// +/// @ingroup FfxGPUParallelSort +#define FFX_PARALLELSORT_SORT_BIN_COUNT (1 << FFX_PARALLELSORT_SORT_BITS_PER_PASS) + +/// The number of elements dealt with per running +/// thread +/// +/// @ingroup FfxGPUParallelSort +#define FFX_PARALLELSORT_ELEMENTS_PER_THREAD 4 + +/// The number of threads to execute in parallel +/// for each dispatch group +/// +/// @ingroup FfxGPUParallelSort +#define FFX_PARALLELSORT_THREADGROUP_SIZE 128 + +/// The maximum number of thread groups to run +/// in parallel. 
Modifying this value can help +/// or hurt GPU occupancy, but is very hardware +/// class specific +/// +/// @ingroup FfxGPUParallelSort +#define FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN 800 + +FFX_GROUPSHARED FfxUInt32 gs_FFX_PARALLELSORT_Histogram[FFX_PARALLELSORT_THREADGROUP_SIZE * FFX_PARALLELSORT_SORT_BIN_COUNT]; +void ffxParallelSortCountUInt(FfxUInt32 localID, FfxUInt32 groupID, FfxUInt32 ShiftBit) +{ + // Start by clearing our local counts in LDS + for (FfxInt32 i = 0; i < FFX_PARALLELSORT_SORT_BIN_COUNT; i++) + gs_FFX_PARALLELSORT_Histogram[(i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID] = 0; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Data is processed in blocks, and how many we process can changed based on how much data we are processing + // versus how many thread groups we are processing with + FfxInt32 BlockSize = FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE; + + // Figure out this thread group's index into the block data (taking into account thread groups that need to do extra reads) + FfxUInt32 NumBlocksPerThreadGroup = FfxNumBlocksPerThreadGroup(); + FfxUInt32 NumThreadGroups = FfxNumThreadGroups(); + FfxUInt32 NumThreadGroupsWithAdditionalBlocks = FfxNumThreadGroupsWithAdditionalBlocks(); + FfxUInt32 NumKeys = FfxNumKeys(); + + FfxUInt32 ThreadgroupBlockStart = (BlockSize * NumBlocksPerThreadGroup * groupID); + FfxUInt32 NumBlocksToProcess = NumBlocksPerThreadGroup; + + if (groupID >= NumThreadGroups - NumThreadGroupsWithAdditionalBlocks) + { + ThreadgroupBlockStart += (groupID - (NumThreadGroups - NumThreadGroupsWithAdditionalBlocks)) * BlockSize; + NumBlocksToProcess++; + } + + // Get the block start index for this thread + FfxUInt32 BlockIndex = ThreadgroupBlockStart + localID; + + // Count value occurrence + for (FfxUInt32 BlockCount = 0; BlockCount < NumBlocksToProcess; BlockCount++, BlockIndex += BlockSize) + { + FfxUInt32 DataIndex = BlockIndex; + + // Pre-load the key values in order to hide some of the read latency + FfxUInt32 srcKeys[FFX_PARALLELSORT_ELEMENTS_PER_THREAD]; + srcKeys[0] = FfxLoadKey(DataIndex); + srcKeys[1] = FfxLoadKey(DataIndex + FFX_PARALLELSORT_THREADGROUP_SIZE); + srcKeys[2] = FfxLoadKey(DataIndex + (FFX_PARALLELSORT_THREADGROUP_SIZE * 2)); + srcKeys[3] = FfxLoadKey(DataIndex + (FFX_PARALLELSORT_THREADGROUP_SIZE * 3)); + + for (FfxUInt32 i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; i++) + { + if (DataIndex < NumKeys) + { + FfxUInt32 localKey = (srcKeys[i] >> ShiftBit) & 0xf; + FFX_ATOMIC_ADD(gs_FFX_PARALLELSORT_Histogram[(localKey * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID], 1); + DataIndex += FFX_PARALLELSORT_THREADGROUP_SIZE; + } + } + } + + // Even though our LDS layout guarantees no collisions, our thread group size is greater than a wave + // so we need to make sure all thread groups are done counting before we start tallying up the results + FFX_GROUP_MEMORY_BARRIER(); + + if (localID < FFX_PARALLELSORT_SORT_BIN_COUNT) + { + FfxUInt32 sum = 0; + for (FfxInt32 i = 0; i < FFX_PARALLELSORT_THREADGROUP_SIZE; i++) + { + sum += gs_FFX_PARALLELSORT_Histogram[localID * FFX_PARALLELSORT_THREADGROUP_SIZE + i]; + } + FfxStoreSum(localID * NumThreadGroups + groupID, sum); + } +} + +FFX_GROUPSHARED FfxUInt32 gs_FFX_PARALLELSORT_LDSSums[FFX_PARALLELSORT_THREADGROUP_SIZE]; +FfxUInt32 ffxParallelSortThreadgroupReduce(FfxUInt32 localSum, FfxUInt32 localID) +{ + // Do wave local reduce +#if defined(FFX_HLSL) + FfxUInt32 waveReduced = WaveActiveSum(localSum); + + // First lane in a wave 
writes out wave reduction to LDS (this accounts for num waves per group greater than HW wave size) + // Note that some hardware with very small HW wave sizes (i.e. <= 8) may exhibit issues with this algorithm, and have not been tested. + FfxUInt32 waveID = localID / WaveGetLaneCount(); + if (WaveIsFirstLane()) + gs_FFX_PARALLELSORT_LDSSums[waveID] = waveReduced; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // First wave worth of threads sum up wave reductions + if (!waveID) + waveReduced = WaveActiveSum((localID < FFX_PARALLELSORT_THREADGROUP_SIZE / WaveGetLaneCount()) ? gs_FFX_PARALLELSORT_LDSSums[localID] : 0); + +#elif defined(FFX_GLSL) + + FfxUInt32 waveReduced = subgroupAdd(localSum); + + // First lane in a wave writes out wave reduction to LDS (this accounts for num waves per group greater than HW wave size) + // Note that some hardware with very small HW wave sizes (i.e. <= 8) may exhibit issues with this algorithm, and have not been tested. + FfxUInt32 waveID = localID / gl_SubGroupSizeARB; + if (subgroupElect()) + gs_FFX_PARALLELSORT_LDSSums[waveID] = waveReduced; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // First wave worth of threads sum up wave reductions + if (waveID == 0) + waveReduced = subgroupAdd((localID < FFX_PARALLELSORT_THREADGROUP_SIZE / gl_SubGroupSizeARB) ? gs_FFX_PARALLELSORT_LDSSums[localID] : 0); + +#endif // #if defined(FFX_HLSL) + + // Returned the reduced sum + return waveReduced; +} + +void ffxParallelSortReduceCount(FfxUInt32 localID, FfxUInt32 groupID) +{ + FfxUInt32 NumReduceThreadgroupPerBin = FfxNumReduceThreadgroupPerBin(); + FfxUInt32 NumThreadGroups = FfxNumThreadGroups(); + + // Figure out what bin data we are reducing + FfxUInt32 BinID = groupID / NumReduceThreadgroupPerBin; + FfxUInt32 BinOffset = BinID * NumThreadGroups; + + // Get the base index for this thread group + FfxUInt32 BaseIndex = (groupID % NumReduceThreadgroupPerBin) * FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE; + + // Calculate partial sums for entries this thread reads in + FfxUInt32 threadgroupSum = 0; + for (FfxUInt32 i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; ++i) + { + FfxUInt32 DataIndex = BaseIndex + (i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID; + threadgroupSum += (DataIndex < NumThreadGroups) ? FfxLoadSum(BinOffset + DataIndex) : 0; + } + + // Reduce across the entirety of the thread group + threadgroupSum = ffxParallelSortThreadgroupReduce(threadgroupSum, localID); + + // First thread of the group writes out the reduced sum for the bin + if (localID == 0) + FfxStoreReduce(groupID, threadgroupSum); + + // What this will look like in the reduced table is: + // [ [bin0 ... bin0] [bin1 ... bin1] ... ] +} + +FfxUInt32 ffxParallelSortBlockScanPrefix(FfxUInt32 localSum, FfxUInt32 localID) +{ +#if defined(FFX_HLSL) + + // Do wave local scan-prefix + FfxUInt32 wavePrefixed = WavePrefixSum(localSum); + + // Since we are dealing with thread group sizes greater than HW wave size, we need to account for what wave we are in. 
+ FfxUInt32 waveID = localID / WaveGetLaneCount(); + FfxUInt32 laneID = WaveGetLaneIndex(); + + // Last element in a wave writes out partial sum to LDS + if (laneID == WaveGetLaneCount() - 1) + gs_FFX_PARALLELSORT_LDSSums[waveID] = wavePrefixed + localSum; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // First wave prefixes partial sums + if (!waveID) + gs_FFX_PARALLELSORT_LDSSums[localID] = WavePrefixSum(gs_FFX_PARALLELSORT_LDSSums[localID]); + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Add the partial sums back to each wave prefix + wavePrefixed += gs_FFX_PARALLELSORT_LDSSums[waveID]; + +#elif defined(FFX_GLSL) + + // Do wave local scan-prefix + FfxUInt32 wavePrefixed = subgroupExclusiveAdd(localSum); + + // Since we are dealing with thread group sizes greater than HW wave size, we need to account for what wave we are in. + FfxUInt32 waveID = localID / gl_SubGroupSizeARB; + FfxUInt32 laneID = gl_SubGroupInvocationARB; + + // Last element in a wave writes out partial sum to LDS + if (laneID == gl_SubGroupSizeARB - 1) + gs_FFX_PARALLELSORT_LDSSums[waveID] = wavePrefixed + localSum; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // First wave prefixes partial sums + if (waveID == 0) + gs_FFX_PARALLELSORT_LDSSums[localID] = subgroupExclusiveAdd(gs_FFX_PARALLELSORT_LDSSums[localID]); + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Add the partial sums back to each wave prefix + wavePrefixed += gs_FFX_PARALLELSORT_LDSSums[waveID]; + +#endif // #if defined(FFX_HLSL) + + return wavePrefixed; +} + +// This is to transform uncoalesced loads into coalesced loads and +// then scattered loads from LDS +FFX_GROUPSHARED FfxUInt32 gs_FFX_PARALLELSORT_LDS[FFX_PARALLELSORT_ELEMENTS_PER_THREAD][FFX_PARALLELSORT_THREADGROUP_SIZE]; +void ffxParallelSortScanPrefix(FfxUInt32 numValuesToScan, FfxUInt32 localID, FfxUInt32 groupID, FfxUInt32 BinOffset, FfxUInt32 BaseIndex, bool AddPartialSums) +{ + // Perform coalesced loads into LDS + for (FfxUInt32 i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; i++) + { + FfxUInt32 DataIndex = BaseIndex + (i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID; + + FfxUInt32 col = ((i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID) / FFX_PARALLELSORT_ELEMENTS_PER_THREAD; + FfxUInt32 row = ((i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID) % FFX_PARALLELSORT_ELEMENTS_PER_THREAD; + gs_FFX_PARALLELSORT_LDS[row][col] = (DataIndex < numValuesToScan) ? 
FfxLoadScanSource(BinOffset + DataIndex) : 0; + } + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + FfxUInt32 threadgroupSum = 0; + // Calculate the local scan-prefix for current thread + for (FfxUInt32 i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; i++) + { + FfxUInt32 tmp = gs_FFX_PARALLELSORT_LDS[i][localID]; + gs_FFX_PARALLELSORT_LDS[i][localID] = threadgroupSum; + threadgroupSum += tmp; + } + + // Scan prefix partial sums + threadgroupSum = ffxParallelSortBlockScanPrefix(threadgroupSum, localID); + + // Add reduced partial sums if requested + FfxUInt32 partialSum = 0; + if (AddPartialSums) + { + // Partial sum additions are a little special as they are tailored to the optimal number of + // thread groups we ran in the beginning, so need to take that into account + partialSum = FfxLoadScanScratch(groupID); + } + + // Add the block scanned-prefixes back in + for (FfxUInt32 i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; i++) + gs_FFX_PARALLELSORT_LDS[i][localID] += threadgroupSum; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Perform coalesced writes to scan dst + for (FfxUInt32 i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; i++) + { + FfxUInt32 DataIndex = BaseIndex + (i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID; + + FfxUInt32 col = ((i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID) / FFX_PARALLELSORT_ELEMENTS_PER_THREAD; + FfxUInt32 row = ((i * FFX_PARALLELSORT_THREADGROUP_SIZE) + localID) % FFX_PARALLELSORT_ELEMENTS_PER_THREAD; + + if (DataIndex < numValuesToScan) + FfxStoreScanDest(BinOffset + DataIndex, gs_FFX_PARALLELSORT_LDS[row][col] + partialSum); + } +} + +// Offset cache to avoid loading the offsets all the time +FFX_GROUPSHARED FfxUInt32 gs_FFX_PARALLELSORT_BinOffsetCache[FFX_PARALLELSORT_THREADGROUP_SIZE]; +// Local histogram for offset calculations +FFX_GROUPSHARED FfxUInt32 gs_FFX_PARALLELSORT_LocalHistogram[FFX_PARALLELSORT_SORT_BIN_COUNT]; +// Scratch area for algorithm +FFX_GROUPSHARED FfxUInt32 gs_FFX_PARALLELSORT_LDSScratch[FFX_PARALLELSORT_THREADGROUP_SIZE]; + +void ffxParallelSortScatterUInt(FfxUInt32 localID, FfxUInt32 groupID, FfxUInt32 ShiftBit) +{ + FfxUInt32 NumBlocksPerThreadGroup = FfxNumBlocksPerThreadGroup(); + FfxUInt32 NumThreadGroups = FfxNumThreadGroups(); + FfxUInt32 NumThreadGroupsWithAdditionalBlocks = FfxNumThreadGroupsWithAdditionalBlocks(); + FfxUInt32 NumKeys = FfxNumKeys(); + + // Load the sort bin threadgroup offsets into LDS for faster referencing + if (localID < FFX_PARALLELSORT_SORT_BIN_COUNT) + gs_FFX_PARALLELSORT_BinOffsetCache[localID] = FfxLoadSum(localID * NumThreadGroups + groupID); + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Data is processed in blocks, and how many we process can changed based on how much data we are processing + // versus how many thread groups we are processing with + int BlockSize = FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE; + + // Figure out this thread group's index into the block data (taking into account thread groups that need to do extra reads) + FfxUInt32 ThreadgroupBlockStart = (BlockSize * NumBlocksPerThreadGroup * groupID); + FfxUInt32 NumBlocksToProcess = NumBlocksPerThreadGroup; + + if (groupID >= NumThreadGroups - NumThreadGroupsWithAdditionalBlocks) + { + ThreadgroupBlockStart += (groupID - (NumThreadGroups - NumThreadGroupsWithAdditionalBlocks)) * BlockSize; + NumBlocksToProcess++; + } + + // Get the block start index for this thread + FfxUInt32 BlockIndex = ThreadgroupBlockStart 
+ localID; + + // Count value occurences + FfxUInt32 newCount; + for (int BlockCount = 0; BlockCount < NumBlocksToProcess; BlockCount++, BlockIndex += BlockSize) + { + FfxUInt32 DataIndex = BlockIndex; + + // Pre-load the key values in order to hide some of the read latency + FfxUInt32 srcKeys[FFX_PARALLELSORT_ELEMENTS_PER_THREAD]; + srcKeys[0] = FfxLoadKey(DataIndex); + srcKeys[1] = FfxLoadKey(DataIndex + FFX_PARALLELSORT_THREADGROUP_SIZE); + srcKeys[2] = FfxLoadKey(DataIndex + (FFX_PARALLELSORT_THREADGROUP_SIZE * 2)); + srcKeys[3] = FfxLoadKey(DataIndex + (FFX_PARALLELSORT_THREADGROUP_SIZE * 3)); + +#ifdef FFX_PARALLELSORT_COPY_VALUE + FfxUInt32 srcValues[FFX_PARALLELSORT_ELEMENTS_PER_THREAD]; + srcValues[0] = FfxLoadPayload(DataIndex); + srcValues[1] = FfxLoadPayload(DataIndex + FFX_PARALLELSORT_THREADGROUP_SIZE); + srcValues[2] = FfxLoadPayload(DataIndex + (FFX_PARALLELSORT_THREADGROUP_SIZE * 2)); + srcValues[3] = FfxLoadPayload(DataIndex + (FFX_PARALLELSORT_THREADGROUP_SIZE * 3)); +#endif // FFX_PARALLELSORT_COPY_VALUE + + for (int i = 0; i < FFX_PARALLELSORT_ELEMENTS_PER_THREAD; i++) + { + // Clear the local histogram + if (localID < FFX_PARALLELSORT_SORT_BIN_COUNT) + gs_FFX_PARALLELSORT_LocalHistogram[localID] = 0; + + FfxUInt32 localKey = (DataIndex < NumKeys ? srcKeys[i] : 0xffffffff); +#ifdef FFX_PARALLELSORT_COPY_VALUE + FfxUInt32 localValue = (DataIndex < NumKeys ? srcValues[i] : 0); +#endif // FFX_PARALLELSORT_COPY_VALUE + + // Sort the keys locally in LDS + for (FfxUInt32 bitShift = 0; bitShift < FFX_PARALLELSORT_SORT_BITS_PER_PASS; bitShift += 2) + { + // Figure out the keyIndex + FfxUInt32 keyIndex = (localKey >> ShiftBit) & 0xf; + FfxUInt32 bitKey = (keyIndex >> bitShift) & 0x3; + + // Create a packed histogram + FfxUInt32 packedHistogram = 1 << (bitKey * 8); + + // Sum up all the packed keys (generates counted offsets up to current thread group) + FfxUInt32 localSum = ffxParallelSortBlockScanPrefix(packedHistogram, localID); + + // Last thread stores the updated histogram counts for the thread group + // Scratch = 0xsum3|sum2|sum1|sum0 for thread group + if (localID == (FFX_PARALLELSORT_THREADGROUP_SIZE - 1)) + gs_FFX_PARALLELSORT_LDSScratch[0] = localSum + packedHistogram; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Load the sums value for the thread group + packedHistogram = gs_FFX_PARALLELSORT_LDSScratch[0]; + + // Add prefix offsets for all 4 bit "keys" (packedHistogram = 0xsum2_1_0|sum1_0|sum0|0) + packedHistogram = (packedHistogram << 8) + (packedHistogram << 16) + (packedHistogram << 24); + + // Calculate the proper offset for this thread's value + localSum += packedHistogram; + + // Calculate target offset + FfxUInt32 keyOffset = (localSum >> (bitKey * 8)) & 0xff; + + // Re-arrange the keys (store, sync, load) + gs_FFX_PARALLELSORT_LDSSums[keyOffset] = localKey; + FFX_GROUP_MEMORY_BARRIER(); + localKey = gs_FFX_PARALLELSORT_LDSSums[localID]; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + +#ifdef FFX_PARALLELSORT_COPY_VALUE + // Re-arrange the values if we have them (store, sync, load) + gs_FFX_PARALLELSORT_LDSSums[keyOffset] = localValue; + FFX_GROUP_MEMORY_BARRIER(); + localValue = gs_FFX_PARALLELSORT_LDSSums[localID]; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); +#endif // FFX_PARALLELSORT_COPY_VALUE + } + + // Need to recalculate the keyIndex on this thread now that values have been copied around the thread group + FfxUInt32 keyIndex = (localKey >> ShiftBit) & 0xf; + + // Reconstruct 
histogram + FFX_ATOMIC_ADD(gs_FFX_PARALLELSORT_LocalHistogram[keyIndex], 1); + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Prefix histogram +#if defined(FFX_HLSL) + FfxUInt32 histogramPrefixSum = WavePrefixSum(localID < FFX_PARALLELSORT_SORT_BIN_COUNT ? gs_FFX_PARALLELSORT_LocalHistogram[localID] : 0); +#elif defined(FFX_GLSL) + FfxUInt32 histogramPrefixSum = subgroupExclusiveAdd(localID < FFX_PARALLELSORT_SORT_BIN_COUNT ? gs_FFX_PARALLELSORT_LocalHistogram[localID] : 0); +#endif // #if defined(FFX_HLSL) + + // Broadcast prefix-sum via LDS + if (localID < FFX_PARALLELSORT_SORT_BIN_COUNT) + gs_FFX_PARALLELSORT_LDSScratch[localID] = histogramPrefixSum; + + // Get the global offset for this key out of the cache + FfxUInt32 globalOffset = gs_FFX_PARALLELSORT_BinOffsetCache[keyIndex]; + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Get the local offset (at this point the keys are all in increasing order from 0 -> num bins in localID 0 -> thread group size) + FfxUInt32 localOffset = localID - gs_FFX_PARALLELSORT_LDSScratch[keyIndex]; + + // Write to destination + FfxUInt32 totalOffset = globalOffset + localOffset; + + if (totalOffset < NumKeys) + { + FfxStoreKey(totalOffset, localKey); + +#ifdef FFX_PARALLELSORT_COPY_VALUE + FfxStorePayload(totalOffset, localValue); +#endif // FFX_PARALLELSORT_COPY_VALUE + } + + // Wait for everyone to catch up + FFX_GROUP_MEMORY_BARRIER(); + + // Update the cached histogram for the next set of entries + if (localID < FFX_PARALLELSORT_SORT_BIN_COUNT) + gs_FFX_PARALLELSORT_BinOffsetCache[localID] += gs_FFX_PARALLELSORT_LocalHistogram[localID]; + + DataIndex += FFX_PARALLELSORT_THREADGROUP_SIZE; // Increase the data offset by thread group size + } + } +} + +// --- Entry points + +#if ENTRY_PARALLEL_SORT_SCAN_REDUCE +// Buffers: rw_sum_table, rw_reduce_table +layout (local_size_x = FFX_PARALLELSORT_THREADGROUP_SIZE, local_size_y = 1, local_size_z = 1) in; +void main() +{ + uint LocalID = gl_LocalInvocationID.x; + uint GroupID = gl_WorkGroupID.x; + ffxParallelSortReduceCount(LocalID, GroupID); +} +#endif + +#if ENTRY_PARALLEL_SORT_SCAN_ADD +// Buffers: rw_scan_source, rw_scan_dest, rw_scan_scratch +layout (local_size_x = FFX_PARALLELSORT_THREADGROUP_SIZE, local_size_y = 1, local_size_z = 1) in; +void main() +{ + uint LocalID = gl_LocalInvocationID.x; + uint GroupID = gl_WorkGroupID.x; + // When doing adds, we need to access data differently because reduce + // has a more specialized access pattern to match optimized count + // Access needs to be done similarly to reduce + // Figure out what bin data we are reducing + uint BinID = GroupID / FfxNumReduceThreadgroupPerBin(); + uint BinOffset = BinID * FfxNumThreadGroups(); + + // Get the base index for this thread group + uint BaseIndex = (GroupID % FfxNumReduceThreadgroupPerBin()) * FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE; + + ffxParallelSortScanPrefix(FfxNumThreadGroups(), LocalID, GroupID, BinOffset, BaseIndex, true); +} +#endif + +#if ENTRY_PARALLEL_SORT_SCAN +// Buffers: rw_scan_source, rw_scan_dest +layout (local_size_x = FFX_PARALLELSORT_THREADGROUP_SIZE, local_size_y = 1, local_size_z = 1) in; +void main() +{ + uint LocalID = gl_LocalInvocationID.x; + uint GroupID = gl_WorkGroupID.x; + uint BaseIndex = FFX_PARALLELSORT_ELEMENTS_PER_THREAD * FFX_PARALLELSORT_THREADGROUP_SIZE * GroupID; + ffxParallelSortScanPrefix(FfxNumScanValues(), LocalID, GroupID, 0, BaseIndex, false); +} +#endif + +#if ENTRY_PARALLEL_SORT_SCATTER +// 
Buffers: rw_source_keys, rw_dest_keys, rw_sum_table, rw_source_payloads, rw_dest_payloads +layout (local_size_x = FFX_PARALLELSORT_THREADGROUP_SIZE, local_size_y = 1, local_size_z = 1) in; +void main() +{ + uint LocalID = gl_LocalInvocationID.x; + uint GroupID = gl_WorkGroupID.x; + ffxParallelSortScatterUInt(LocalID, GroupID, FfxShiftBit()); +} +#endif + +#if ENTRY_PARALLEL_SORT_COUNT +// Buffers: rw_source_keys, rw_sum_table +layout (local_size_x = FFX_PARALLELSORT_THREADGROUP_SIZE, local_size_y = 1, local_size_z = 1) in; +void main() +{ + uint LocalID = gl_LocalInvocationID.x; + uint GroupID = gl_WorkGroupID.x; + ffxParallelSortCountUInt(LocalID, GroupID, FfxShiftBit()); +} +#endif \ No newline at end of file diff --git a/aitviewer/utils/gpu_sort.py b/aitviewer/utils/gpu_sort.py new file mode 100644 index 0000000..f917fb7 --- /dev/null +++ b/aitviewer/utils/gpu_sort.py @@ -0,0 +1,161 @@ +# Copyright (C) 2023 ETH Zurich, Manuel Kaufmann, Velko Vechev, Dario Mylonopoulos +import moderngl +import numpy as np + +from aitviewer.shaders import get_sort_program + +# Adapted from parallelsort algorithm in FidelityFX-SDK +# https://github.com/GPUOpen-LibrariesAndSDKs/FidelityFX-SDK/tree/main +# +# FidelityFX-SDK License +# +# Copyright (C) 2023 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and /or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions : +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +class GpuSort: + # All these constants must match the respective constants in the sort.glsl shader. + FFX_PARALLELSORT_ELEMENTS_PER_THREAD = 4 + FFX_PARALLELSORT_THREADGROUP_SIZE = 128 + FFX_PARALLELSORT_SORT_BITS_PER_PASS = 4 + FFX_PARALLELSORT_SORT_BIN_COUNT = 1 << FFX_PARALLELSORT_SORT_BITS_PER_PASS + FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN = 800 + + FFX_PARALLELSORT_BIND_UAV_SOURCE_KEYS = 0 + FFX_PARALLELSORT_BIND_UAV_DEST_KEYS = 1 + FFX_PARALLELSORT_BIND_UAV_SOURCE_PAYLOADS = 2 + FFX_PARALLELSORT_BIND_UAV_DEST_PAYLOADS = 3 + FFX_PARALLELSORT_BIND_UAV_SUM_TABLE = 4 + FFX_PARALLELSORT_BIND_UAV_REDUCE_TABLE = 5 + FFX_PARALLELSORT_BIND_UAV_SCAN_SOURCE = 6 + FFX_PARALLELSORT_BIND_UAV_SCAN_DEST = 7 + FFX_PARALLELSORT_BIND_UAV_SCAN_SCRATCH = 8 + + def __init__(self, ctx: moderngl.Context, count: int): + self.count = count + + # Create programs. 
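+        # One compute program per pass of the radix sort; all are compiled from sort.glsl with a
+        # different ENTRY_PARALLEL_SORT_* define selecting the entry point.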
+ self.prog_count: moderngl.ComputeShader = get_sort_program("COUNT") + self.prog_scan_reduce: moderngl.ComputeShader = get_sort_program("SCAN_REDUCE") + self.prog_scan: moderngl.ComputeShader = get_sort_program("SCAN") + self.prog_scan_add: moderngl.ComputeShader = get_sort_program("SCAN_ADD") + self.prog_scatter: moderngl.ComputeShader = get_sort_program("SCATTER") + + def div_round_up(n, d): + return (n + d - 1) // d + + # Buffers. + block_size = self.FFX_PARALLELSORT_ELEMENTS_PER_THREAD * self.FFX_PARALLELSORT_THREADGROUP_SIZE + num_blocks = div_round_up(count, block_size) + num_reduced_blocks = div_round_up(num_blocks, block_size) + + scratch_buffer_size = self.FFX_PARALLELSORT_SORT_BIN_COUNT * num_blocks + reduce_scratch_buffer_size = self.FFX_PARALLELSORT_SORT_BIN_COUNT * num_reduced_blocks + + self.sort_scratch_buf = ctx.buffer(reserve=count * 4) + self.payload_scratch_buf = ctx.buffer(reserve=count * 4) + self.scratch_buf = ctx.buffer(reserve=scratch_buffer_size * 4) + self.reduced_scratch_buf = ctx.buffer(reserve=reduce_scratch_buffer_size * 4) + + # Constants. + num_thread_groups_to_run = self.FFX_PARALLELSORT_MAX_THREADGROUPS_TO_RUN + blocks_per_thread_group = num_blocks // num_thread_groups_to_run + num_thread_groups_with_additional_blocks = num_blocks % num_thread_groups_to_run + + if num_blocks < num_thread_groups_to_run: + blocks_per_thread_group = 1 + num_thread_groups_to_run = num_blocks + num_thread_groups_with_additional_blocks = 0 + + num_reduce_thread_groups_to_run = self.FFX_PARALLELSORT_SORT_BIN_COUNT * ( + 1 if block_size > num_thread_groups_to_run else div_round_up(num_thread_groups_to_run, block_size) + ) + num_reduce_thread_groups_per_bin = num_reduce_thread_groups_to_run // self.FFX_PARALLELSORT_SORT_BIN_COUNT + num_scan_values = num_reduce_thread_groups_to_run + + constants = np.array( + [ + count, + blocks_per_thread_group, + num_thread_groups_to_run, + num_thread_groups_with_additional_blocks, + num_reduce_thread_groups_per_bin, + num_scan_values, + 0, + 0, + ], + np.uint32, + ) + self.constants_buf = ctx.buffer(constants.tobytes()) + + self.num_thread_groups_to_run = num_thread_groups_to_run + self.num_reduce_thread_groups_to_run = num_reduce_thread_groups_to_run + + def run(self, ctx: moderngl.Context, keys: moderngl.Buffer, values: moderngl.Buffer): + shift_data = np.array([0], np.uint32) + + src_keys = keys + src_payload = values + dst_keys = self.sort_scratch_buf + dst_payload = self.payload_scratch_buf + + for shift in range(0, 32, self.FFX_PARALLELSORT_SORT_BITS_PER_PASS): + shift_data[0] = shift + self.constants_buf.write(shift_data.tobytes(), offset=6 * 4) + self.constants_buf.bind_to_uniform_block(0) + + src_keys.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SOURCE_KEYS) + self.scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SUM_TABLE) + self.prog_count.run(self.num_thread_groups_to_run) + ctx.memory_barrier() + + self.scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SUM_TABLE) + self.reduced_scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_REDUCE_TABLE) + self.prog_scan_reduce.run(self.num_reduce_thread_groups_to_run) + ctx.memory_barrier() + + self.reduced_scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SCAN_SOURCE) + self.reduced_scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SCAN_DEST) + self.prog_scan.run(1) + ctx.memory_barrier() + + self.scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SCAN_SOURCE) + 
self.scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SCAN_DEST) + self.reduced_scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SCAN_SCRATCH) + self.prog_scan_add.run(self.num_reduce_thread_groups_to_run) + ctx.memory_barrier() + + src_keys.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SOURCE_KEYS) + dst_keys.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_DEST_KEYS) + src_payload.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SOURCE_PAYLOADS) + dst_payload.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_DEST_PAYLOADS) + self.scratch_buf.bind_to_storage_buffer(self.FFX_PARALLELSORT_BIND_UAV_SUM_TABLE) + self.prog_scatter.run(self.num_thread_groups_to_run) + ctx.memory_barrier() + + src_keys, dst_keys = dst_keys, src_keys + src_payload, dst_payload = dst_payload, src_payload + + def release(self): + self.sort_scratch_buf.release() + self.payload_scratch_buf.release() + self.scratch_buf.release() + self.reduced_scratch_buf.release() + self.constants_buf.release() diff --git a/examples/gaussian_splatting.py b/examples/gaussian_splatting.py new file mode 100644 index 0000000..2c07106 --- /dev/null +++ b/examples/gaussian_splatting.py @@ -0,0 +1,104 @@ +import json +import os + +import imgui +import numpy as np + +from aitviewer.renderables.gaussian_splats import GaussianSplats +from aitviewer.scene.camera import OpenCVCamera +from aitviewer.scene.node import Node +from aitviewer.viewer import Viewer + +# Update this variable to point to the Gaussian Splatting dataset that can be downloaded +# from here https://github.com/graphdeco-inria/gaussian-splatting link is "Pre-trained Models (14 GB)"". +# +# This variable should point to the top level directory containing a directory for each scene. +PATH = "" + +if not PATH: + print( + "Update this variable to point to the Gaussian Splatting dataset that can be downloaded" + ' from here https://github.com/graphdeco-inria/gaussian-splatting clicking on "Pre-trained Models"' + ) + exit(1) + +dataset = {f: os.path.join(PATH, f) for f in sorted(os.listdir(PATH))} + + +gs = None +cameras = Node("Cameras") + + +def set_scene(viewer, name, iteration): + global gs, cameras + if gs is not None: + viewer.scene.remove(gs) + for c in cameras.nodes: + cameras.remove(c) + + path = os.path.join(dataset[name], "point_cloud", f"iteration_{iteration}", "point_cloud.ply") + to_y_up = np.array([[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], np.float32) + gs = GaussianSplats.from_ply(path, rotation=to_y_up[:3, :3]) + gs._debug_gui = True + print(f"Loaded {gs.num_splats} splats") + + cams = json.load(open(os.path.join(dataset[name], "cameras.json"))) + for c in cams: + t = np.array(c["position"]) + R = np.array(c["rotation"]).reshape(3, 3) + fx = c["fx"] + fy = c["fy"] + w = c["width"] + h = c["height"] + name = str(c["id"]) + + Rt = np.eye(4) + Rt[:3, :3] = R + Rt[:3, 3] = t + Rt = (np.linalg.inv(Rt) @ to_y_up)[:3, :4] + + K = np.array( + [ + [fx, 0, w / 2], + [0, fy, h / 2], + [0, 0, 1], + ] + ) + + c = OpenCVCamera(K, Rt, w, h, name=name, viewer=viewer) + cameras.add(c) + + viewer.scene.add(gs) + + +v = Viewer(size=(1600, 900)) +v.auto_set_floor = False +v.auto_set_camera_target = False +v.scene.floor.enabled = False +v.scene.background_color = (0, 0, 0, 1) +v.scene.add(cameras, enabled=False) + + +def gui_dataset(): + imgui.set_next_window_position(v.window_size[0] - 200, 50, imgui.FIRST_USE_EVER) + imgui.set_next_window_size(v.window_size[0] * 0.2, v.window_size[1] * 0.5, imgui.FIRST_USE_EVER) + 
expanded, _ = imgui.begin("Dataset", None) + if expanded: + for i, k in enumerate(dataset.keys()): + space = imgui.get_content_region_available()[0] + imgui.text(k) + imgui.same_line() + imgui.set_cursor_pos_x(space - 150) + if imgui.button(f"7000##{i}", width=70): + set_scene(v, k, 7000) + imgui.same_line() + if imgui.button(f"30000##{i}", width=70): + set_scene(v, k, 30000) + imgui.end() + + +v.gui_controls["gs_dataset"] = gui_dataset + +set_scene(v, "bicycle", 30000) + +v.run()
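
A minimal way to try the new renderable outside of this example, assuming a point_cloud.ply trained with the reference Gaussian Splatting implementation (the path below is a placeholder):

    from aitviewer.renderables.gaussian_splats import GaussianSplats
    from aitviewer.viewer import Viewer

    # Placeholder path to a trained splat point cloud.
    splats = GaussianSplats.from_ply("point_cloud.ply")

    viewer = Viewer()
    viewer.scene.add(splats)
    viewer.run()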