You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I wrote a draft to serialize a compressed Gaussian splatting in Python. Unfortunately, it does not work correctly. When loading the values into supersplat, I often see NaNs or extremely high/low values.
I don’t have much time right now to work on it, but it would be great if someone could take a look at my code and use it as a starting point to get it working properly.
Below, I’ve attached the snippet and a splat to run it. I hope someone can improve it and make it work.
Thanks in advance!
import torch
import math
import numpy as np
import struct
from io import BytesIO
from plyfile import PlyData
def sh2rgb(sh: torch.Tensor) -> torch.Tensor:
C0 = 0.28209479177387814
return sh * C0 + 0.5
def part1by2_vec(x: torch.Tensor) -> torch.Tensor:
x = x & 0x000003FF
x = (x ^ (x << 16)) & 0xFF0000FF
x = (x ^ (x << 8)) & 0x0300F00F
x = (x ^ (x << 4)) & 0x030C30C3
x = (x ^ (x << 2)) & 0x09249249
return x
def encode_morton3_vec(
x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
) -> torch.Tensor:
return (part1by2_vec(z) << 2) + (part1by2_vec(y) << 1) + part1by2_vec(x)
def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
# Compute min and max values in a single operation
min_vals, _ = torch.min(centers, dim=0)
max_vals, _ = torch.max(centers, dim=0)
# Compute the scaling factors
lengths = max_vals - min_vals
lengths[lengths == 0] = 1 # Prevent division by zero
# Normalize and scale to 10-bit integer range (0-1024)
scaled_centers = ((centers - min_vals) / lengths * 1024).floor().to(torch.int32)
# Extract x, y, z coordinates
x, y, z = scaled_centers[:, 0], scaled_centers[:, 1], scaled_centers[:, 2]
# Compute Morton codes using vectorized operations
morton = encode_morton3_vec(x, y, z)
# Sort indices based on Morton codes
sorted_indices = indices[torch.argsort(morton)]
return sorted_indices
def pack_unorm(value: float, bits: int) -> int:
t = (1 << bits) - 1
return max(0, min(t, math.floor(value * t)))
def pack_111011(x: float, y: float, z: float) -> int:
return (pack_unorm(x, 11) << 21) | (pack_unorm(y, 10) << 11) | (pack_unorm(z, 11))
def pack_8888(x: float, y: float, z: float, w: float) -> int:
return (
(pack_unorm(x, 8) << 24)
| (pack_unorm(y, 8) << 16)
| (pack_unorm(z, 8) << 8)
| (pack_unorm(w, 8))
)
def pack_rot(x: float, y: float, z: float, w: float) -> int:
q = np.array([x, y, z, w])
q /= np.linalg.norm(q) + 1e-8 # Normalize the quaternion
largest = np.argmax(np.abs(q)) # Find the index of the largest component
if q[largest] < 0: # Ensure positivity for consistency
q = -q
norm = np.sqrt(2) * 0.5
result = largest
for i in range(4):
if i != largest:
result = (result << 10) | pack_unorm(q[i] * norm + 0.5, 10)
return result
def ply_bytes_compressed(
means: torch.Tensor, # Shape (N, 3)
scales: torch.Tensor, # Shape (N, 3)
quats: torch.Tensor, # Shape (N, 4)
sh0: torch.Tensor, # Shape (N, 3)
shN: torch.Tensor, # Shape (N, K)
opacities: torch.Tensor, # Shape (N,)
) -> bytes:
chunk_max_size = 256
sh0_colors = sh2rgb(sh0)
mask = torch.sigmoid(opacities) > (1 / 255)
means = means[mask]
scales = scales[mask]
sh0_colors = sh0_colors[mask]
shN = shN[mask]
quats = quats[mask]
opacities = opacities[mask]
num_splats = means.shape[0]
n_chunks = num_splats // chunk_max_size + (num_splats % chunk_max_size != 0)
indices = torch.arange(num_splats)
indices = sort_centers(means, indices)
float_properties = [
"min_x",
"min_y",
"min_z",
"max_x",
"max_y",
"max_z",
"min_scale_x",
"min_scale_y",
"min_scale_z",
"max_scale_x",
"max_scale_y",
"max_scale_z",
"min_r",
"min_g",
"min_b",
"max_r",
"max_g",
"max_b",
]
uint_properties = [
"packed_position",
"packed_rotation",
"packed_scale",
"packed_color",
]
buffer = BytesIO()
# Write PLY header
buffer.write(b"ply\n")
buffer.write(b"format binary_little_endian 1.0\n")
buffer.write(f"element chunk {n_chunks}\n".encode())
for prop in float_properties:
buffer.write(f"property float {prop}\n".encode())
buffer.write(f"element vertex {num_splats}\n".encode())
for prop in uint_properties:
buffer.write(f"property uint {prop}\n".encode())
buffer.write(f"element sh {num_splats}\n".encode())
for j in range(shN.shape[1]):
buffer.write(f"property uchar f_rest_{j}\n".encode())
buffer.write(b"end_header\n")
for chunk_idx in range(n_chunks):
chunk_end_idx = min((chunk_idx + 1) * chunk_max_size, num_splats)
chunk_size = chunk_end_idx - chunk_idx * chunk_max_size
chunk_start_idx = chunk_idx * chunk_max_size
splat_idxs = indices[chunk_start_idx:chunk_end_idx]
if chunk_size < chunk_max_size:
# Take elements from the previous chunk if the current one is not full
missing_idxs = chunk_max_size - chunk_size
additional_idxs = indices[chunk_start_idx - missing_idxs : chunk_start_idx]
splat_idxs = torch.cat([splat_idxs, additional_idxs])
# Write the bounds in the chunk
# Means
min_means_bounds = torch.min(means[splat_idxs], dim=0).values
max_means_bounds = torch.max(means[splat_idxs], dim=0).values
buffer.write(struct.pack("<3f", *min_means_bounds.tolist()))
buffer.write(struct.pack("<3f", *max_means_bounds.tolist()))
# Scales
min_scales_bounds = torch.min(scales[splat_idxs], dim=0).values
max_scales_bounds = torch.max(scales[splat_idxs], dim=0).values
min_scales_bounds = torch.clamp(min_scales_bounds, -20, 20)
max_scales_bounds = torch.clamp(max_scales_bounds, -20, 20)
buffer.write(struct.pack("<3f", *min_scales_bounds.tolist()))
buffer.write(struct.pack("<3f", *max_scales_bounds.tolist()))
# Colors
min_colors_bounds = torch.min(sh0_colors[splat_idxs], dim=0).values
max_colors_bounds = torch.max(sh0_colors[splat_idxs], dim=0).values
buffer.write(struct.pack("<3f", *min_colors_bounds.tolist()))
buffer.write(struct.pack("<3f", *max_colors_bounds.tolist()))
# Write the quantized remaining properties
normalized_means = (means[splat_idxs] - min_means_bounds) / (
max_means_bounds - min_means_bounds
)
normalized_scales = (scales[splat_idxs] - min_scales_bounds) / (
max_scales_bounds - min_scales_bounds
)
normalized_colors = (sh0_colors[splat_idxs] - min_colors_bounds) / (
max_colors_bounds - min_colors_bounds
)
chunk_quats = quats[splat_idxs]
chunk_opacities = opacities[splat_idxs]
for i in range(chunk_max_size):
# Means
means_i = pack_111011(*normalized_means[i].tolist())
# Quaternions
quat_i = pack_rot(*chunk_quats[i].tolist())
# Scales
scales_i = pack_111011(*normalized_scales[i].tolist())
# Colors
normalized_colors_i = normalized_colors[i].tolist()
opacity = 1 / (1 + torch.exp(-chunk_opacities[i])).item()
normalized_colors_i.append(opacity)
color_i = pack_8888(*normalized_colors_i)
# Print the values class
buffer.write(struct.pack("<I", means_i))
buffer.write(struct.pack("<I", quat_i))
buffer.write(struct.pack("<I", scales_i))
buffer.write(struct.pack("<I", color_i))
# Write quantized spherical harmonics SH
shN_chunk = shN[splat_idxs]
shN_chunk_quantized = (shN_chunk / 8 + 0.5) * 256
for value in shN_chunk_quantized.ravel():
value = max(0, min(255, math.trunc(value.item())))
buffer.write(struct.pack("B", value))
return buffer.getvalue()
ply_path = "model.ply"
plydata = PlyData.read(ply_path)
vert = plydata["vertex"]
# First means
means = torch.tensor([vert["x"], vert["y"], vert["z"]]).T
quats = torch.tensor([vert["rot_0"], vert["rot_1"], vert["rot_2"], vert["rot_3"]]).T
sh_0 = torch.tensor([vert["f_dc_0"], vert["f_dc_1"], vert["f_dc_2"]]).T
sh_n = torch.ones((means.shape[0], 45))
for i in range(45):
sh_n[:, i] = torch.tensor(vert[f"f_rest_{i}"])
opacities = torch.tensor(vert["opacity"])
scale = torch.tensor([vert["scale_0"], vert["scale_1"], vert["scale_2"]]).T
ply_bytes_compressed = ply_bytes_compressed(means, scale, quats, sh_0, sh_n, opacities)
with open("model_compressed.ply", "wb") as f:
f.write(ply_bytes_compressed)
No, I didn't know about the standalone project—thanks for sharing it with me! I'll take a look.
I started writing the Python compressor because I'm working with Gaussian splatting in a personal project. Since I already have a method to save it as a PLY file, I wanted to see if it was also possible to save it as a compressed PLY.
Hi everyone,
I wrote a draft to serialize a compressed Gaussian splatting in Python. Unfortunately, it does not work correctly. When loading the values into supersplat, I often see NaNs or extremely high/low values.
I don’t have much time right now to work on it, but it would be great if someone could take a look at my code and use it as a starting point to get it working properly.
Below, I’ve attached the snippet and a splat to run it. I hope someone can improve it and make it work.
Thanks in advance!
Link to example ply splat
The text was updated successfully, but these errors were encountered: