Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize Compressed Splat in Python #419

Open
bolopenguin opened this issue Feb 25, 2025 · 2 comments
Open

Serialize Compressed Splat in Python #419

bolopenguin opened this issue Feb 25, 2025 · 2 comments

Comments

@bolopenguin
Copy link

Hi everyone,

I wrote a draft to serialize a compressed Gaussian splatting in Python. Unfortunately, it does not work correctly. When loading the values into supersplat, I often see NaNs or extremely high/low values.

I don’t have much time right now to work on it, but it would be great if someone could take a look at my code and use it as a starting point to get it working properly.

Below, I’ve attached the snippet and a splat to run it. I hope someone can improve it and make it work.

Thanks in advance!

import torch
import math
import numpy as np
import struct
from io import BytesIO
from plyfile import PlyData


def sh2rgb(sh: torch.Tensor) -> torch.Tensor:
    C0 = 0.28209479177387814
    return sh * C0 + 0.5


def part1by2_vec(x: torch.Tensor) -> torch.Tensor:
    x = x & 0x000003FF
    x = (x ^ (x << 16)) & 0xFF0000FF
    x = (x ^ (x << 8)) & 0x0300F00F
    x = (x ^ (x << 4)) & 0x030C30C3
    x = (x ^ (x << 2)) & 0x09249249
    return x


def encode_morton3_vec(
    x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
) -> torch.Tensor:
    return (part1by2_vec(z) << 2) + (part1by2_vec(y) << 1) + part1by2_vec(x)


def sort_centers(centers: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Compute min and max values in a single operation
    min_vals, _ = torch.min(centers, dim=0)
    max_vals, _ = torch.max(centers, dim=0)

    # Compute the scaling factors
    lengths = max_vals - min_vals
    lengths[lengths == 0] = 1  # Prevent division by zero

    # Normalize and scale to 10-bit integer range (0-1024)
    scaled_centers = ((centers - min_vals) / lengths * 1024).floor().to(torch.int32)

    # Extract x, y, z coordinates
    x, y, z = scaled_centers[:, 0], scaled_centers[:, 1], scaled_centers[:, 2]

    # Compute Morton codes using vectorized operations
    morton = encode_morton3_vec(x, y, z)

    # Sort indices based on Morton codes
    sorted_indices = indices[torch.argsort(morton)]

    return sorted_indices


def pack_unorm(value: float, bits: int) -> int:
    t = (1 << bits) - 1
    return max(0, min(t, math.floor(value * t)))


def pack_111011(x: float, y: float, z: float) -> int:
    return (pack_unorm(x, 11) << 21) | (pack_unorm(y, 10) << 11) | (pack_unorm(z, 11))


def pack_8888(x: float, y: float, z: float, w: float) -> int:
    return (
        (pack_unorm(x, 8) << 24)
        | (pack_unorm(y, 8) << 16)
        | (pack_unorm(z, 8) << 8)
        | (pack_unorm(w, 8))
    )


def pack_rot(x: float, y: float, z: float, w: float) -> int:
    q = np.array([x, y, z, w])
    q /= np.linalg.norm(q) + 1e-8  # Normalize the quaternion

    largest = np.argmax(np.abs(q))  # Find the index of the largest component

    if q[largest] < 0:  # Ensure positivity for consistency
        q = -q

    norm = np.sqrt(2) * 0.5
    result = largest

    for i in range(4):
        if i != largest:
            result = (result << 10) | pack_unorm(q[i] * norm + 0.5, 10)

    return result


def ply_bytes_compressed(
    means: torch.Tensor,  # Shape (N, 3)
    scales: torch.Tensor,  # Shape (N, 3)
    quats: torch.Tensor,  # Shape (N, 4)
    sh0: torch.Tensor,  # Shape (N, 3)
    shN: torch.Tensor,  # Shape (N, K)
    opacities: torch.Tensor,  # Shape (N,)
) -> bytes:
    chunk_max_size = 256
    sh0_colors = sh2rgb(sh0)

    mask = torch.sigmoid(opacities) > (1 / 255)
    means = means[mask]
    scales = scales[mask]
    sh0_colors = sh0_colors[mask]
    shN = shN[mask]
    quats = quats[mask]
    opacities = opacities[mask]

    num_splats = means.shape[0]
    n_chunks = num_splats // chunk_max_size + (num_splats % chunk_max_size != 0)
    indices = torch.arange(num_splats)
    indices = sort_centers(means, indices)

    float_properties = [
        "min_x",
        "min_y",
        "min_z",
        "max_x",
        "max_y",
        "max_z",
        "min_scale_x",
        "min_scale_y",
        "min_scale_z",
        "max_scale_x",
        "max_scale_y",
        "max_scale_z",
        "min_r",
        "min_g",
        "min_b",
        "max_r",
        "max_g",
        "max_b",
    ]
    uint_properties = [
        "packed_position",
        "packed_rotation",
        "packed_scale",
        "packed_color",
    ]
    buffer = BytesIO()

    # Write PLY header
    buffer.write(b"ply\n")
    buffer.write(b"format binary_little_endian 1.0\n")
    buffer.write(f"element chunk {n_chunks}\n".encode())
    for prop in float_properties:
        buffer.write(f"property float {prop}\n".encode())
    buffer.write(f"element vertex {num_splats}\n".encode())
    for prop in uint_properties:
        buffer.write(f"property uint {prop}\n".encode())
    buffer.write(f"element sh {num_splats}\n".encode())
    for j in range(shN.shape[1]):
        buffer.write(f"property uchar f_rest_{j}\n".encode())
    buffer.write(b"end_header\n")

    for chunk_idx in range(n_chunks):
        chunk_end_idx = min((chunk_idx + 1) * chunk_max_size, num_splats)
        chunk_size = chunk_end_idx - chunk_idx * chunk_max_size
        chunk_start_idx = chunk_idx * chunk_max_size
        splat_idxs = indices[chunk_start_idx:chunk_end_idx]
        if chunk_size < chunk_max_size:
            # Take elements from the previous chunk if the current one is not full
            missing_idxs = chunk_max_size - chunk_size
            additional_idxs = indices[chunk_start_idx - missing_idxs : chunk_start_idx]
            splat_idxs = torch.cat([splat_idxs, additional_idxs])

        # Write the bounds in the chunk
        # Means
        min_means_bounds = torch.min(means[splat_idxs], dim=0).values
        max_means_bounds = torch.max(means[splat_idxs], dim=0).values
        buffer.write(struct.pack("<3f", *min_means_bounds.tolist()))
        buffer.write(struct.pack("<3f", *max_means_bounds.tolist()))
        # Scales
        min_scales_bounds = torch.min(scales[splat_idxs], dim=0).values
        max_scales_bounds = torch.max(scales[splat_idxs], dim=0).values
        min_scales_bounds = torch.clamp(min_scales_bounds, -20, 20)
        max_scales_bounds = torch.clamp(max_scales_bounds, -20, 20)
        buffer.write(struct.pack("<3f", *min_scales_bounds.tolist()))
        buffer.write(struct.pack("<3f", *max_scales_bounds.tolist()))
        # Colors
        min_colors_bounds = torch.min(sh0_colors[splat_idxs], dim=0).values
        max_colors_bounds = torch.max(sh0_colors[splat_idxs], dim=0).values
        buffer.write(struct.pack("<3f", *min_colors_bounds.tolist()))
        buffer.write(struct.pack("<3f", *max_colors_bounds.tolist()))

        # Write the quantized remaining properties
        normalized_means = (means[splat_idxs] - min_means_bounds) / (
            max_means_bounds - min_means_bounds
        )
        normalized_scales = (scales[splat_idxs] - min_scales_bounds) / (
            max_scales_bounds - min_scales_bounds
        )
        normalized_colors = (sh0_colors[splat_idxs] - min_colors_bounds) / (
            max_colors_bounds - min_colors_bounds
        )

        chunk_quats = quats[splat_idxs]
        chunk_opacities = opacities[splat_idxs]
        for i in range(chunk_max_size):
            # Means
            means_i = pack_111011(*normalized_means[i].tolist())
            # Quaternions
            quat_i = pack_rot(*chunk_quats[i].tolist())
            # Scales
            scales_i = pack_111011(*normalized_scales[i].tolist())
            # Colors
            normalized_colors_i = normalized_colors[i].tolist()
            opacity = 1 / (1 + torch.exp(-chunk_opacities[i])).item()
            normalized_colors_i.append(opacity)
            color_i = pack_8888(*normalized_colors_i)
            # Print the values class
            buffer.write(struct.pack("<I", means_i))
            buffer.write(struct.pack("<I", quat_i))
            buffer.write(struct.pack("<I", scales_i))
            buffer.write(struct.pack("<I", color_i))

        # Write quantized spherical harmonics SH
        shN_chunk = shN[splat_idxs]
        shN_chunk_quantized = (shN_chunk / 8 + 0.5) * 256
        for value in shN_chunk_quantized.ravel():
            value = max(0, min(255, math.trunc(value.item())))
            buffer.write(struct.pack("B", value))

    return buffer.getvalue()


ply_path = "model.ply"
plydata = PlyData.read(ply_path)
vert = plydata["vertex"]

# First means
means = torch.tensor([vert["x"], vert["y"], vert["z"]]).T
quats = torch.tensor([vert["rot_0"], vert["rot_1"], vert["rot_2"], vert["rot_3"]]).T
sh_0 = torch.tensor([vert["f_dc_0"], vert["f_dc_1"], vert["f_dc_2"]]).T
sh_n = torch.ones((means.shape[0], 45))
for i in range(45):
    sh_n[:, i] = torch.tensor(vert[f"f_rest_{i}"])
opacities = torch.tensor(vert["opacity"])
scale = torch.tensor([vert["scale_0"], vert["scale_1"], vert["scale_2"]]).T

ply_bytes_compressed = ply_bytes_compressed(means, scale, quats, sh_0, sh_n, opacities)

with open("model_compressed.ply", "wb") as f:
    f.write(ply_bytes_compressed)

Link to example ply splat

@slimbuck
Copy link
Member

Hi @bolopenguin ,

This looks cool!

Would you mind giving some context as to why you need a python compressor specifically?

Do you know about https://github.com/playcanvas/splat-transform? It's a stand-alone js-based compressor tool.

Thanks!

@bolopenguin
Copy link
Author

No, I didn't know about the standalone project—thanks for sharing it with me! I'll take a look.

I started writing the Python compressor because I'm working with Gaussian splatting in a personal project. Since I already have a method to save it as a PLY file, I wanted to see if it was also possible to save it as a compressed PLY.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants