Skip to content

Commit

Permalink
Remove unnecessary argsort in CUDAHashMap.keys() (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Feb 6, 2025
1 parent 8bc4e7a commit 5ec2ff3
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pyg_lib/csrc/classes/cuda/hash_map.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ struct CUDAHashMapImpl : HashMapImpl {

map_->retrieve_all(key_data, value_data);

return key.index_select(0, value.argsort());
const auto perm = at::empty_like(value);
perm[value] = at::arange(value.numel(), value.options());

return key.index_select(0, perm);
}

private:
Expand Down

0 comments on commit 5ec2ff3

Please sign in to comment.