Skip to content

Commit

Permalink
Fix permute_multi_embedding kernel (#3227)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3227

X-link: facebookresearch/FBGEMM#325

Looks like a typo to use `permute_id = threadIdx.y + blockIdx.x * blockDim.x` which should be `blockDim.y`. This doesn't affect Nvidia because blockDim.x and y are both 32 (32 threads per warp + 32 warps). For AMD GPU, blockDim.x is 64 and blockDim.y is 16, causing numerical issues.

Reviewed By: leitian, jianyuh, joebos

Differential Revision: D63936776

fbshipit-source-id: cfdf0ff24e41a8ffd137ce066f2b82f3c47399b5
  • Loading branch information
xw285cornell authored and facebook-github-bot committed Oct 5, 2024
1 parent 9408072 commit 1815f89
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using Tensor = at::Tensor;

namespace fbgemm_gpu {

// Kernerl for permute pooled embedding op.
// Kernel for permute pooled embedding op.
// This kernel is moving D elements per warp.
template <typename scalar_t, bool reverse_permute>
__global__ void permute_multi_embs_kernel(
Expand All @@ -40,7 +40,7 @@ __global__ void permute_multi_embs_kernel(
const int32_t permute_size) {
// workers in a warp handle exact one permute (of a feature/key)
const int32_t worker_id = threadIdx.x;
const int32_t permute_id = threadIdx.y + blockIdx.x * blockDim.x;
const int32_t permute_id = threadIdx.y + blockIdx.x * blockDim.y;
const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
if (batch_id >= batch_size) {
return;
Expand Down

0 comments on commit 1815f89

Please sign in to comment.