Performance Bottleneck in KNN Kernel with larger K values #231

Open
IlliaOvcharenko opened this issue Nov 8, 2024 · 0 comments

In my project, I am using PointCNN for a segmentation task. I recently profiled it with NVIDIA Nsight Systems to identify potential bottlenecks. During these tests, I observed that the KNN kernel consumed approximately 89% of the total inference time, which seems abnormally high.

Below, I have included several screenshots that highlight this performance issue:

  1. CUDA kernel summary. The KNN kernel took ~89% of the total inference time.
[Screenshot: CUDA kernel summary, taken 2024-10-31 at 17:23:59]
  2. Single-batch inference analysis. The KNN operation in the dec3 layer consumed nearly half of the inference time.
[Screenshot: pointcnn-batch-timeline]
  3. KNN/FPS execution times and input shapes. The tests were run with a batch size of 24, with 8192 points per item.

[Screenshot: pointcnn-inference-iteration]

The execution time of the KNN operation grows much faster than linearly as the k parameter increases. Below are execution times for several values of k (same number of input points, different numbers of neighbors):

Layer    k (= kernel_size × dilation)    Execution time (ms)
conv1     8                               13
dec4     32                              241
dec3     48                              681
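A quick way to characterize the growth from these three measurements: if the time were exponential in k, the growth factor per unit of k would stay roughly constant; if it followed a power law t ∝ k^p, the log-log slope p would stay roughly constant. The sketch below computes both between consecutive rows of the table. With only three points this is a rough indication, not a fit.

```python
import math

# (k, execution time in ms) from the table; all three layers run on the
# same full-resolution point set.
measurements = [(8, 13.0), (32, 241.0), (48, 681.0)]


def local_exponents(points):
    """Log-log slope between consecutive points: t ~ k**p => p = dlog(t) / dlog(k)."""
    return [
        math.log(t2 / t1) / math.log(k2 / k1)
        for (k1, t1), (k2, t2) in zip(points, points[1:])
    ]


def per_k_growth(points):
    """Growth factor per +1 in k; roughly constant if t ~ exp(c * k)."""
    return [
        (t2 / t1) ** (1.0 / (k2 - k1))
        for (k1, t1), (k2, t2) in zip(points, points[1:])
    ]


print("power-law exponents:", [round(p, 2) for p in local_exponents(measurements)])
print("per-k growth factors:", [round(g, 3) for g in per_k_growth(measurements)])
```

The local exponents come out at roughly 2.1 and 2.6, while the per-k growth factor drops from about 1.13 to about 1.07. That points to growth somewhat steeper than k² rather than a strictly exponential trend, although neither model fits three points exactly.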

I reviewed the CUDA implementation of KNN and suspect that the main cause of this slowdown is the maintenance of the sorted best_dist and best_idx arrays.

  // n_y is the current query point, for which we compute
  // the k nearest neighbors among the n_x candidate points

  // for every candidate point in the same example
  for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
    // ...
    // compute the distance from n_y to n_x and store it in tmp_dist
    // ...

    // insert the candidate into the sorted best_dist/best_idx arrays
    // (insertion sort: up to k comparisons and k - 1 shifts per candidate);
    // probably the slowest part as k grows
    for (int64_t e1 = 0; e1 < k; e1++) {
      if (best_dist[e1] > tmp_dist) {
        for (int64_t e2 = k - 1; e2 > e1; e2--) {
          best_dist[e2] = best_dist[e2 - 1];
          best_idx[e2] = best_idx[e2 - 1];
        }
        best_dist[e1] = tmp_dist;
        best_idx[e1] = n_x;
        break;
      }
    }
  }
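To make the per-candidate cost of that loop easier to reason about, here is a pure-Python mirror of the excerpt for a single query point. It assumes squared Euclidean distance (the distance computation is elided in the excerpt), and the helper names `dist2` and `knn_insertion` are hypothetical. It counts comparisons and shifts so their growth with k can be compared against the timings, and it cross-checks the result against a plain sort.

```python
import math
import random


def dist2(a, b):
    """Squared Euclidean distance (assumed metric)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def knn_insertion(query, candidates, k):
    """Mirror of the kernel excerpt for one query n_y.

    best_dist/best_idx stay sorted ascending. Each candidate is compared from
    the front; on insertion the tail is shifted one slot to the right.
    Returns (best_idx, comparisons, shifts).
    """
    best_dist = [math.inf] * k
    best_idx = [-1] * k
    comparisons = shifts = 0
    for n_x, point in enumerate(candidates):
        tmp_dist = dist2(query, point)
        for e1 in range(k):
            comparisons += 1
            if best_dist[e1] > tmp_dist:
                # Same as the e2 loop: move entries e1..k-2 one slot right.
                for e2 in range(k - 1, e1, -1):
                    best_dist[e2] = best_dist[e2 - 1]
                    best_idx[e2] = best_idx[e2 - 1]
                    shifts += 1
                best_dist[e1] = tmp_dist
                best_idx[e1] = n_x
                break
    return best_idx, comparisons, shifts


random.seed(0)
pts = [(random.random(), random.random(), random.random()) for _ in range(2000)]
query = pts[0]
for k in (8, 32, 48):
    idx, comps, shifts = knn_insertion(query, pts, k)
    # Cross-check against sorting all distances.
    ref = sorted(range(len(pts)), key=lambda i: dist2(query, pts[i]))[:k]
    assert idx == ref
    print(f"k={k:2d}  comparisons={comps}  shifts={shifts}")
```

Because the comparison is strict (`>`), candidates with equal distance keep their original order, so the result matches a stable sort truncated to k. Every candidate that does not make the top k still costs k comparisons, so the scan alone is proportional to N × k per query.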

I think there are two main issues with the current code:

  1. Maintaining the sorted best_dist array inside knn_kernel appears to take significant time, and this per-thread insertion sort is not an efficient pattern for GPU execution.
  2. Distances are recomputed on every FPS/KNN call; it may make sense to compute them once and reuse them.
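On point 2: conv1, dec4 and dec3 all call knn_graph on the same enc1_pos, with k = 8, 32 and 48. Assuming the KNN op emits each query's neighbors contiguously and in ascending-distance order (which is what the insertion-sorted best_idx in the kernel excerpt suggests), the k = 8 and k = 32 results would be prefixes of the k = 48 result, so a single query could serve all three layers. A minimal sketch on plain lists; `strided_neighbors` and the neighbor ids are hypothetical:

```python
def strided_neighbors(sorted_nbrs, kernel_size, dilation):
    """Neighbors an XConv layer keeps for one query: every `dilation`-th of the
    first kernel_size * dilation nearest neighbors. This matches
    edge_index[:, ::dilation] when each query has exactly
    kernel_size * dilation contiguous edges."""
    need = kernel_size * dilation
    if len(sorted_nbrs) < need:
        raise ValueError(f"need {need} sorted neighbors, got {len(sorted_nbrs)}")
    return sorted_nbrs[:need:dilation]


# Full-resolution layers in PointCnnSegm: (kernel_size, dilation).
full_res_layers = {"conv1": (8, 1), "dec4": (8, 4), "dec3": (8, 6)}
k_max = max(ks * d for ks, d in full_res_layers.values())

# Hypothetical neighbor ids of one query, sorted by distance, as a single
# knn call with k = k_max would return them.
cached = list(range(100, 100 + k_max))
for name, (ks, d) in full_res_layers.items():
    print(name, strided_neighbors(cached, ks, d))
```

With tensors, the same idea would roughly be reshaping the k_max edge list to (2, N, k_max) and slicing the last dimension per layer; that variant is not shown or verified here.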

Questions

  1. Is there a fundamental issue with my implementation, or am I using the KNN/FPS operations incorrectly?
  2. Would precomputing the distances between points on the CPU, in the data loader, be a good option to consider?

Implementation Details

Below is the PointCNN implementation used in this project (the model is compiled with torch.compile, except for the KNN and FPS operations):

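import torch
from torch.nn.functional import relu
from torch_geometric.nn import fps, knn_interpolate

class PointCnnSegm(torch.nn.Module):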
    def __init__(
        self,
        num_classes: int,
        x_features: int = 3,
        fps_ratio: list[float] = [0.1, 0.5, 0.334],
    ):
        super().__init__()
        self.conv1 = XConv(x_features, 256, dim=3,
                           kernel_size=8, hidden_channels=128)
        self.conv2 = XConv(256, 256, dim=3, kernel_size=12, dilation=2)
        self.conv3 = XConv(256, 512, dim=3, kernel_size=16, dilation=2)
        self.conv4 = XConv(512, 1024, dim=3, kernel_size=16, dilation=6)

        self.dec1 = XConv(1024+512, 512, dim=3, kernel_size=16, dilation=6)
        self.dec2 = XConv(512+256, 256, dim=3, kernel_size=12, dilation=6)
        self.dec3 = XConv(256+256, 256, dim=3, kernel_size=8, dilation=6)
        self.dec4 = XConv(256+256, 256, dim=3, kernel_size=8, dilation=4)

        self.head = torch.nn.Sequential(
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, num_classes)
        )

        self.fps_ratio = fps_ratio

    def forward(self, enc1_x, enc1_pos, enc1_batch):

        enc1_x = relu(self.conv1(enc1_x, enc1_pos, enc1_batch))

        idx = fps(enc1_pos, enc1_batch, ratio=self.fps_ratio[0])
        enc2_x, enc2_pos, enc2_batch = \
            enc1_x[idx], enc1_pos[idx], enc1_batch[idx]
        enc2_x = relu(self.conv2(enc2_x, enc2_pos, enc2_batch))

        idx = fps(enc2_pos, enc2_batch, ratio=self.fps_ratio[1])
        enc3_x, enc3_pos, enc3_batch = \
            enc2_x[idx], enc2_pos[idx], enc2_batch[idx]
        enc3_x = relu(self.conv3(enc3_x, enc3_pos, enc3_batch))

        idx = fps(enc3_pos, enc3_batch, ratio=self.fps_ratio[2])
        enc4_x, enc4_pos, enc4_batch = \
            enc3_x[idx], enc3_pos[idx], enc3_batch[idx]
        enc4_x = relu(self.conv4(enc4_x, enc4_pos, enc4_batch))


        dec1_x = knn_interpolate(enc4_x, enc4_pos, enc3_pos,
                                 enc4_batch, enc3_batch, k=3)
        dec1_x = torch.cat([dec1_x, enc3_x], dim=1)
        dec1_x = relu(self.dec1(dec1_x, enc3_pos, enc3_batch))

        dec2_x = knn_interpolate(dec1_x, enc3_pos, enc2_pos,
                                 enc3_batch, enc2_batch, k=3)
        dec2_x = torch.cat([dec2_x, enc2_x], dim=1)
        dec2_x = relu(self.dec2(dec2_x, enc2_pos, enc2_batch))

        dec3_x = knn_interpolate(dec2_x, enc2_pos, enc1_pos,
                                 enc2_batch, enc1_batch, k=3)
        dec3_x = torch.cat([dec3_x, enc1_x], dim=1)
        dec3_x = relu(self.dec3(dec3_x, enc1_pos, enc1_batch))

        dec4_x = torch.cat([dec3_x, enc1_x], dim=1)
        dec4_x = relu(self.dec4(dec4_x, enc1_pos, enc1_batch))

        out = self.head(dec4_x)
        return out
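To see which layers should dominate, here is a crude cost model for the XConv KNN calls in the model above, assuming a brute-force search where each of n query points scans all n candidates in its batch item with up to k comparisons each (n × n × k), with k = kernel_size × dilation as passed to knn_graph in XConv. It also assumes FPS keeps ceil(n × ratio) points per item; the knn_interpolate calls (k = 3) are left out.

```python
import math

N = 8192  # points per batch item, as in the profiling setup
fps_ratio = [0.1, 0.5, 0.334]

# Points per item at each resolution, assuming FPS keeps ceil(n * ratio).
n1 = N
n2 = math.ceil(n1 * fps_ratio[0])
n3 = math.ceil(n2 * fps_ratio[1])
n4 = math.ceil(n3 * fps_ratio[2])

# (layer, points the XConv runs knn_graph on, kernel_size, dilation)
layers = [
    ("conv1", n1, 8, 1), ("conv2", n2, 12, 2), ("conv3", n3, 16, 2),
    ("conv4", n4, 16, 6), ("dec1", n3, 16, 6), ("dec2", n2, 12, 6),
    ("dec3", n1, 8, 6), ("dec4", n1, 8, 4),
]

# Brute-force model per batch item: n queries * n candidates * k comparisons.
work = {name: n * n * ks * d for name, n, ks, d in layers}
total = sum(work.values())
for name, n, ks, d in layers:
    print(f"{name:6s} n={n:5d} k={ks * d:3d} share={work[name] / total:6.1%}")
```

Under this model the three full-resolution layers (conv1, dec3, dec4) account for over 98% of the KNN work, which is consistent with the timeline. The model is linear in k, though, so it predicts dec3 at only 6× conv1, whereas the measured ratio is about 52×.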

XConv implementation:

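from math import ceil

import torch
from torch_cluster import knn_graph
from torch_geometric.nn import Reshape
from torch_geometric.nn.inits import reset

class XConv(torch.nn.Module):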
    def __init__(self, in_channels: int, out_channels: int, dim: int,
                 kernel_size: int, hidden_channels: int | None = None,
                 dilation: int = 1, bias: bool = True, num_workers: int = 1):
        super().__init__()

        self.in_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels // 4
        assert hidden_channels > 0
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.num_workers = num_workers

        C_in, C_delta, C_out = in_channels, hidden_channels, out_channels
        D, K = dim, kernel_size

        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(dim, C_delta),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(C_delta),
            torch.nn.Linear(C_delta, C_delta),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(C_delta),
            Reshape(-1, K, C_delta),
        )

        self.mlp2 = torch.nn.Sequential(
            torch.nn.Linear(D * K, K**2),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(K**2),
            Reshape(-1, K, K),
            torch.nn.Conv1d(K, K**2, K, groups=K),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(K**2),
            Reshape(-1, K, K),
            torch.nn.Conv1d(K, K**2, K, groups=K),
            torch.nn.BatchNorm1d(K**2),
            Reshape(-1, K, K),
        )

        C_in = C_in + C_delta
        depth_multiplier = int(ceil(C_out / C_in))
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),
            Reshape(-1, C_in * depth_multiplier),
            torch.nn.Linear(C_in * depth_multiplier, C_out, bias=bias),
        )

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.mlp1)
        reset(self.mlp2)
        reset(self.conv)

    def forward(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        batch: torch.Tensor | None = None
    ):
        r"""Runs the forward pass of the module."""
        pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
        (N, D), K = pos.size(), self.kernel_size

        edge_index = knn_graph(pos, K * self.dilation, batch, loop=True,
                               flow='target_to_source',
                               num_workers=self.num_workers)

        if self.dilation > 1:
            edge_index = edge_index[:, ::self.dilation]

        row, col = edge_index[0], edge_index[1]

        pos = pos[col] - pos[row]

        x_star = self.mlp1(pos)
        if x is not None:
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            x = x[col].view(N, K, self.in_channels)
            x_star = torch.cat([x_star, x], dim=-1)
        x_star = x_star.transpose(1, 2).contiguous()

        transform_matrix = self.mlp2(pos.view(N, K * D))

        x_transformed = torch.matmul(x_star, transform_matrix)

        out = self.conv(x_transformed)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')
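In XConv.forward, knn_graph is asked for K × dilation neighbors, and `edge_index[:, ::self.dilation]` then keeps every dilation-th edge of the flat edge list. The sketch below mimics that on plain lists, assuming each query contributes exactly K × dilation contiguous, distance-sorted edges; `flat_dilated` is a hypothetical helper and the tuple (q, j) stands for "rank-j neighbor of query q". It shows that each query keeps ranks 0, d, 2d, ..., and how many of the neighbors found are discarded.

```python
def flat_dilated(num_queries, kernel_size, dilation):
    """Mimic knn_graph output followed by edge_index[:, ::dilation] on a flat list.

    Each query q contributes kernel_size * dilation edges, encoded as (q, rank).
    The flat stride lines up per query only because that count is a multiple
    of `dilation`.
    """
    k = kernel_size * dilation
    flat = [(q, j) for q in range(num_queries) for j in range(k)]
    return flat[::dilation]


kept = flat_dilated(num_queries=3, kernel_size=8, dilation=6)
print("query 1 keeps ranks:", [j for q, j in kept if q == 1])

for name, ks, d in [("conv1", 8, 1), ("dec4", 8, 4), ("dec3", 8, 6)]:
    found = ks * d
    print(f"{name}: knn finds {found}, XConv keeps {ks}, discards {1 - ks / found:.0%}")
```

So dec3 pays for a k = 48 search although its kernel_size is 8, which is why its KNN time follows the dilated k rather than the kernel size.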

Environment Details

  • PyTorch == 2.2.1
  • PyG == 2.5.0
  • Torch Cluster == 1.6.3
  • Python 3.10
  • NVIDIA GeForce RTX 4090
  • CUDA Version 12.5

Thank you for your assistance!
