diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f8ca4155..f4d8b7fc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added ### Changed - Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267)) +- Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270)) ### Removed ## [0.3.0] - 2023-10-11 diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index cfc679d01..f26ee13af 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -223,7 +223,19 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { - const auto index = at::multinomial(weight, count, replace); + at::Tensor index; + if (replace) { + // at::multinomial only has good performance for `replace=true`, see: + // https://github.com/pytorch/pytorch/issues/11931 + index = at::multinomial(weight, count, replace); + } else { + // For `replace=false`, we make use of the implementation of the + // "Weighted Random Sampling" paper: + // https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf + const auto rand = at::empty_like(weight).uniform_(); + const auto key = (rand.log() / weight); + index = std::get<1>(key.topk(count)); + } const auto index_data = index.data_ptr(); for (size_t i = 0; i < index.numel(); ++i) { add(row_start + index_data[i], global_src_node, local_src_node,