add fp16 support for fps_cuda and ball_query_cuda
eamonn-zh committed Dec 23, 2024
1 parent 055ab3a commit 8831615
Showing 5 changed files with 138 additions and 67 deletions.
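For context before the per-file diffs: a minimal usage sketch of the two ops this commit touches, assuming a PyTorch3D build that includes these changes (shapes and values are illustrative):

    import torch
    from pytorch3d.ops import ball_query, sample_farthest_points

    # Half-precision clouds on the GPU now run through native fp16 kernels
    pts = torch.randn(2, 4096, 3, device="cuda", dtype=torch.float16)

    # Farthest point sampling: K samples per cloud
    sampled, idx = sample_farthest_points(pts, K=64)

    # Ball query: up to K neighbors of each point within the given radius
    p2 = torch.randn(2, 8192, 3, device="cuda", dtype=torch.float16)
    res = ball_query(pts, p2, K=16, radius=0.2)

    print(sampled.dtype, res.dists.dtype)  # torch.float16 torch.float16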
14 changes: 7 additions & 7 deletions pytorch3d/csrc/ball_query/ball_query.cu
@@ -32,7 +32,7 @@ __global__ void BallQueryKernel(
     at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
     at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
     const int64_t K,
-    const float radius2) {
+    const scalar_t radius2) {
   const int64_t N = p1.size(0);
   const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
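Because radius2 now takes the kernel's scalar type, the fp16 path inherits half-precision range limits: the squared radius is converted to half at the kernel boundary, and anything above float16's maximum (about 65504) saturates to inf, which effectively disables the radius filter. This is an observation about the fp16 path, not something the commit itself asserts; a quick check:

    import torch

    # float16 tops out near 65504, so a radius of 300 squares past the range
    print(torch.tensor(300.0 * 300.0).half())   # tensor(inf, dtype=torch.float16)

    # radii up to ~255 square safely
    print(torch.tensor(255.0 * 255.0).half())   # tensor(65024., dtype=torch.float16)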
@@ -95,7 +95,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int64_t K_64 = K;
-  const float radius2 = radius * radius;
+  const auto radius2 = radius * radius;
 
   // Output tensor with indices of neighbors for each point in p1
   auto long_dtype = lengths1.options().dtype(at::kLong);
@@ -110,15 +110,15 @@
   const size_t blocks = 256;
   const size_t threads = 256;
 
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       p1.scalar_type(), "ball_query_kernel_cuda", ([&] {
-        BallQueryKernel<<<blocks, threads, 0, stream>>>(
-            p1.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
-            p2.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
+        BallQueryKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+            p1.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
+            p2.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
             lengths1.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
             lengths2.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>(),
             idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
-            dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
+            dists.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(),
             K_64,
             radius2);
       }));
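AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda for double, float, and half and branches on the runtime scalar_type(), so BallQueryKernel<scalar_t> is compiled once per dtype. From Python the dispatch is invisible; a sanity check along these lines should hold, assuming the ops-level wrapper passes dtypes through unchanged:

    import torch
    from pytorch3d.ops import ball_query

    for dtype in (torch.float16, torch.float32):
        p1 = torch.randn(1, 128, 3, device="cuda", dtype=dtype)
        p2 = torch.randn(1, 256, 3, device="cuda", dtype=dtype)
        out = ball_query(p1, p2, K=8, radius=0.3)
        # dists come back in the input dtype; unfilled neighbor slots are -1
        assert out.dists.dtype == dtype and out.idx.dtype == torch.int64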
165 changes: 118 additions & 47 deletions pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu
@@ -15,19 +15,19 @@
 #include <cub/cub.cuh>
 #include "utils/warp_reduce.cuh"
 
-template <unsigned int block_size>
+template <unsigned int block_size, typename scalar_t>
 __global__ void FarthestPointSamplingKernel(
     // clang-format off
-    const at::PackedTensorAccessor64<float, 3, at::RestrictPtrTraits> points,
+    const at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> points,
     const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> lengths,
     const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> K,
     at::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits> idxs,
-    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> min_point_dist,
+    at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> min_point_dist,
     const at::PackedTensorAccessor64<int64_t, 1, at::RestrictPtrTraits> start_idxs
     // clang-format on
 ) {
   typedef cub::BlockReduce<
-      cub::KeyValuePair<int64_t, float>,
+      cub::KeyValuePair<int64_t, scalar_t>,
       block_size,
       cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
       BlockReduce;
@@ -57,24 +57,24 @@ __global__ void FarthestPointSamplingKernel(
   // Keep track of the maximum of the minimum distance to previously selected
   // points seen by this thread
   int64_t max_dist_idx = 0;
-  float max_dist = -1.0;
+  scalar_t max_dist = -1.0;
 
   // Iterate through all the points in this pointcloud. For already selected
   // points, the minimum distance to the set of previously selected points
   // will be 0.0 so they won't be selected again.
   for (int64_t p = tid; p < lengths[batch_idx]; p += block_size) {
     // Calculate the distance to the last selected point
-    float dist2 = 0.0;
+    scalar_t dist2 = 0.0;
     for (int64_t d = 0; d < D; ++d) {
-      float diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
+      scalar_t diff = points[batch_idx][selected][d] - points[batch_idx][p][d];
       dist2 += (diff * diff);
     }
 
     // If the distance of point p to the last selected point is
     // less than the previous minimum distance of p to the set of selected
     // points, then updated the corresponding value in min_point_dist
     // so it always contains the min distance.
-    const float p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
+    const scalar_t p_min_dist = min(dist2, min_point_dist[batch_idx][p]);
     min_point_dist[batch_idx][p] = p_min_dist;
 
     // Update the max distance and point idx for this thread.
@@ -88,7 +88,7 @@
     selected =
         BlockReduce(temp_storage)
             .Reduce(
-                cub::KeyValuePair<int64_t, float>(max_dist_idx, max_dist),
+                cub::KeyValuePair<int64_t, scalar_t>(max_dist_idx, max_dist),
                 cub::ArgMax(),
                 block_size)
            .key;
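The kernel's per-iteration pattern (update every point's minimum distance to the already-selected set, then ArgMax-reduce across the block to pick the farthest point) is the classic sequential algorithm parallelized over one block per cloud. As a reference, a single-cloud sketch in plain PyTorch; fps_reference is a hypothetical helper written for this note, not part of PyTorch3D:

    import torch

    def fps_reference(points: torch.Tensor, k: int, start_idx: int = 0) -> torch.Tensor:
        # points: (P, D). Returns indices of k farthest-point samples.
        P = points.shape[0]
        min_dist = torch.full(
            (P,), torch.finfo(points.dtype).max, dtype=points.dtype, device=points.device
        )
        idxs = torch.empty(k, dtype=torch.long, device=points.device)
        selected = start_idx
        for i in range(k):
            idxs[i] = selected
            # Squared distance from every point to the newly selected point
            dist2 = ((points - points[selected]) ** 2).sum(-1)
            # Keep the min distance to the selected set; selected points
            # drop to 0 and are never chosen again
            min_dist = torch.minimum(min_dist, dist2)
            selected = int(torch.argmax(min_dist))
        return idxs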
@@ -109,13 +109,16 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points, // (N, P, 3)
     const at::Tensor& lengths, // (N,)
     const at::Tensor& K, // (N,)
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const at::Tensor& min_point_dist // (N, P)
+) {
   // Check inputs are on the same device
-  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
-      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
-  at::CheckedFrom c = "FarthestPointSamplingCuda";
-  at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t});
-  at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});
+  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
+      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4},
+      min_point_dist_t{min_point_dist, "min_point_dist", 5};
+  at::CheckedFrom c = "FarthestPointSamplingCuda";
+  at::checkAllSameGPU(c, {p_t, lengths_t, k_t, start_idxs_t, min_point_dist_t});
+  at::checkAllSameType(c, {p_t, min_point_dist_t});
+  at::checkAllSameType(c, {lengths_t, k_t, start_idxs_t});
 
   // Set the device for the kernel launch based on the device of points
   at::cuda::CUDAGuard device_guard(points.device());
@@ -135,7 +138,6 @@
 
   // Initialize the output tensor with the sampled indices
   auto idxs = at::full({N, max_K}, -1, lengths.options());
-  auto min_point_dist = at::full({N, P}, 1e10, points.options());
 
   if (N == 0 || P == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
@@ -158,15 +160,10 @@
   const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 2);
 
   // Create the accessors
-  auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
-  auto lengths_a =
-      lengths.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
+  auto lengths_a = lengths.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
   auto K_a = K.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
   auto idxs_a = idxs.packed_accessor64<int64_t, 2, at::RestrictPtrTraits>();
-  auto start_idxs_a =
-      start_idxs.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
-  auto min_point_dist_a =
-      min_point_dist.packed_accessor64<float, 2, at::RestrictPtrTraits>();
+  auto start_idxs_a = start_idxs.packed_accessor64<int64_t, 1, at::RestrictPtrTraits>();
 
   // TempStorage for the reduction uses static shared memory only.
   size_t shared_mem = 0;
@@ -175,50 +172,124 @@
   // block.
   switch (threads) {
     case 1024:
-      FarthestPointSamplingKernel<1024>
-          <<<blocks, threads, shared_mem, stream>>>(
-              points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<1024, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 512:
-      FarthestPointSamplingKernel<512><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<512, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 256:
-      FarthestPointSamplingKernel<256><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<256, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          }));
      break;
    case 128:
-      FarthestPointSamplingKernel<128><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<128, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 64:
-      FarthestPointSamplingKernel<64><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<64, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 32:
-      FarthestPointSamplingKernel<32><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<32, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 16:
-      FarthestPointSamplingKernel<16><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<16, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 8:
-      FarthestPointSamplingKernel<8><<<blocks, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<8, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 4:
-      FarthestPointSamplingKernel<4><<<threads, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<4, scalar_t>
+                <<<threads, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
       break;
     case 2:
-      FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
-          points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<2, scalar_t>
+                <<<threads, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
      break;
    default:
-      FarthestPointSamplingKernel<1024>
-          <<<blocks, threads, shared_mem, stream>>>(
-              points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+          points.scalar_type(), "fps_kernel_cuda", ([&] {
+            FarthestPointSamplingKernel<1024, scalar_t>
+                <<<blocks, threads, shared_mem, stream>>>(
+                    points.packed_accessor64<scalar_t, 3, at::RestrictPtrTraits>(), lengths_a, K_a, idxs_a,
+                    min_point_dist.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(), start_idxs_a
+                );
+          })
+      );
   }
 
   AT_CUDA_CHECK(cudaGetLastError());
9 changes: 6 additions & 3 deletions pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h
@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs);
+    const at::Tensor& start_idxs,
+    const at::Tensor& min_point_dist);
 
 at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
@@ -56,14 +57,16 @@ at::Tensor FarthestPointSampling(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const at::Tensor& min_point_dist) {
   if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(points);
     CHECK_CUDA(lengths);
     CHECK_CUDA(K);
     CHECK_CUDA(start_idxs);
-    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
+    CHECK_CUDA(min_point_dist);
+    return FarthestPointSamplingCuda(points, lengths, K, start_idxs, min_point_dist);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
7 changes: 0 additions & 7 deletions pytorch3d/ops/ball_query.py
@@ -36,13 +36,6 @@ def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
     @once_differentiable
     def backward(ctx, grad_dists, grad_idx):
         p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
-        # TODO(gkioxari) Change cast to floats once we add support for doubles.
-        if not (grad_dists.dtype == torch.float32):
-            grad_dists = grad_dists.float()
-        if not (p1.dtype == torch.float32):
-            p1 = p1.float()
-        if not (p2.dtype == torch.float32):
-            p2 = p2.float()
 
         # Reuse the KNN backward function
         # by default, norm is 2
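With these casts gone, gradients from ball_query flow through the KNN backward path in the input's own dtype. A sketch of the expected behavior after this change; this is a behavioral assumption inferred from the removed casts (the KNN backward kernels themselves are not part of this diff), not a documented guarantee:

    import torch
    from pytorch3d.ops import ball_query

    p1 = torch.randn(1, 64, 3, device="cuda", dtype=torch.float16, requires_grad=True)
    p2 = torch.randn(1, 128, 3, device="cuda", dtype=torch.float16)

    out = ball_query(p1, p2, K=8, radius=0.5)
    # dists of padded (missing) neighbors are zero, so the sum is well defined
    out.dists.sum().backward()
    print(p1.grad.dtype)  # torch.float16: no silent upcast to float32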
10 changes: 7 additions & 3 deletions pytorch3d/ops/sample_farthest_points.py
@@ -74,23 +74,27 @@ def sample_farthest_points(
         raise ValueError("K and points must have the same batch dimension")
 
     # Check dtypes are correct and convert if necessary
-    if not (points.dtype == torch.float32):
-        points = points.to(torch.float32)
     if not (lengths.dtype == torch.int64):
         lengths = lengths.to(torch.int64)
     if not (K.dtype == torch.int64):
         K = K.to(torch.int64)
 
     # Generate the starting indices for sampling
     start_idxs = torch.zeros_like(lengths)
 
+    # Generate the minimum point distance array
+    min_point_dist = torch.full(
+        (N, P), torch.finfo(points.dtype).max, dtype=points.dtype, device=device
+    )
+
     if random_start_point:
         for n in range(N):
             # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
             start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
 
     with torch.no_grad():
         # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
-        idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
+        idx = _C.sample_farthest_points(points, lengths, K, start_idxs, min_point_dist)
     sampled_points = masked_gather(points, idx)
 
     return sampled_points, idx
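Allocating min_point_dist in Python lets the fill value track the input dtype. The old hard-coded 1e10 cannot be represented in float16 (recent PyTorch versions refuse the scalar conversion outright), which is plausibly why the buffer moved out of the CUDA host code; torch.finfo picks a per-dtype "effectively infinite" starting distance instead:

    import torch

    print(torch.finfo(torch.float16).max)   # 65504.0
    print(torch.finfo(torch.float32).max)   # ~3.4028e+38

    # The old constant does not fit in half precision; in recent PyTorch,
    # torch.full((1,), 1e10, dtype=torch.float16) raises
    # "RuntimeError: value cannot be converted to type at::Half without overflow"
    seed = torch.full((4,), torch.finfo(torch.float16).max, dtype=torch.float16)
    print(seed)  # tensor([65504., 65504., 65504., 65504.], dtype=torch.float16)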
