diff --git a/csrc/cpu/rw_cpu.cpp b/csrc/cpu/rw_cpu.cpp
index c26b28e..1ee7aa6 100644
--- a/csrc/cpu/rw_cpu.cpp
+++ b/csrc/cpu/rw_cpu.cpp
@@ -137,3 +137,168 @@ random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
 
   return std::make_tuple(n_out, e_out);
 }
+
+
+void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
+                 float_t *edge_weight_cdf, int64_t numel) {
+  /* Convert edge weights to a per-row CDF as given in [1].
+
+  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L148
+  */
+  at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE,
+                   [&](int64_t begin, int64_t end) {
+    for (int64_t i = begin; i < end; i++) {
+      int64_t row_start = rowptr[i], row_end = rowptr[i + 1];
+      float_t acc = 0.0;
+
+      for (int64_t j = row_start; j < row_end; j++) {
+        acc += edge_weight[j];
+        edge_weight_cdf[j] = acc;
+      }
+    }
+  });
+}
+
+
+int64_t get_offset(const float_t *edge_weight, int64_t start, int64_t end) {
+  /*
+  The implementation given in [1] utilizes the `searchsorted` function of
+  NumPy. It is also available in PyTorch and its C++ API (via
+  `at::searchsorted()`). However, that implementation covers the general
+  case, in which the searched values can be a multi-dimensional tensor. In
+  our case, we have a 1D tensor of edge weights (in the form of a
+  Cumulative Distribution Function) and a single value whose position we
+  want to compute. To eliminate the overhead introduced by the PyTorch
+  implementation, one can examine the source code of `searchsorted` [2] and
+  find that for our case the whole function call reduces to calling the
+  `cus_lower_bound()` function. Unfortunately, we cannot access it directly
+  (its namespace is not exposed in the public API), but the implementation
+  is just a simple binary search. The code was copied here and reduced to
+  the bare minimum.
+
+  [1] https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
+  [2] https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Bucketization.cpp
+  */
+  float_t value = (float_t)rand() / RAND_MAX; // [0, 1]
+  int64_t original_start = start, original_end = end;
+
+  while (start < end) {
+    const int64_t mid = start + ((end - start) >> 1);
+    const float_t mid_val = edge_weight[mid];
+    if (!(mid_val >= value)) {
+      start = mid + 1;
+    } else {
+      end = mid;
+    }
+  }
+
+  if (start == original_end) {
+    // Guard against floating-point round-off: `value` can lie above the
+    // last CDF entry of the row; clamp to the last edge in that case.
+    start = original_end - 1;
+  }
+
+  return start - original_start;
+}
+
+// See: https://louisabraham.github.io/articles/node2vec-sampling.html
+// See also: https://github.com/louisabraham/fastnode2vec/blob/master/fastnode2vec/graph.py#L69
+void rejection_sampling_weighted(const int64_t *rowptr, const int64_t *col,
+                                 const float_t *edge_weight_cdf, int64_t *start,
+                                 int64_t *n_out, int64_t *e_out,
+                                 const int64_t numel, const int64_t walk_length,
+                                 const double p, const double q) {
+
+  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
+  double prob_0 = 1. / p / max_prob;
+  double prob_1 = 1. / max_prob;
+  double prob_2 = 1. / q / max_prob;
+
+  int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
+  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
+    for (auto n = begin; n < end; n++) {
+      int64_t t = start[n], v, x, e_cur, row_start, row_end;
+
+      n_out[n * (walk_length + 1)] = t;
+
+      row_start = rowptr[t], row_end = rowptr[t + 1];
+
+      if (row_end - row_start == 0) {
+        e_cur = -1;
+        v = t;
+      } else {
+        e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
+        v = col[e_cur];
+      }
+      n_out[n * (walk_length + 1) + 1] = v;
+      e_out[n * walk_length] = e_cur;
+
+      for (auto l = 1; l < walk_length; l++) {
+        row_start = rowptr[v], row_end = rowptr[v + 1];
+
+        if (row_end - row_start == 0) {
+          e_cur = -1;
+          x = v;
+        } else if (row_end - row_start == 1) {
+          e_cur = row_start;
+          x = col[e_cur];
+        } else {
+          if (p == 1 && q == 1) {
+            e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
+            x = col[e_cur];
+          } else {
+            while (true) {
+              e_cur = row_start + get_offset(edge_weight_cdf, row_start, row_end);
+              x = col[e_cur];
+
+              auto r = (double)rand() / RAND_MAX; // [0, 1]
+
+              if (x == t && r < prob_0)
+                break;
+              else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
+                break;
+              else if (r < prob_2)
+                break;
+            }
+          }
+        }
+
+        n_out[n * (walk_length + 1) + (l + 1)] = x;
+        e_out[n * walk_length + l] = e_cur;
+        t = v;
+        v = x;
+      }
+    }
+  });
+}
+
+
+std::tuple<torch::Tensor, torch::Tensor>
+random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
+                         torch::Tensor edge_weight, torch::Tensor start,
+                         int64_t walk_length, double p, double q) {
+  CHECK_CPU(rowptr);
+  CHECK_CPU(col);
+  CHECK_CPU(edge_weight);
+  CHECK_CPU(start);
+
+  CHECK_INPUT(rowptr.dim() == 1);
+  CHECK_INPUT(col.dim() == 1);
+  CHECK_INPUT(edge_weight.dim() == 1);
+  CHECK_INPUT(start.dim() == 1);
+
+  auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
+  auto e_out = torch::empty({start.size(0), walk_length}, start.options());
+
+  auto rowptr_data = rowptr.data_ptr<int64_t>();
+  auto col_data = col.data_ptr<int64_t>();
+  auto edge_weight_data = edge_weight.data_ptr<float_t>();
+  auto start_data = start.data_ptr<int64_t>();
+  auto n_out_data = n_out.data_ptr<int64_t>();
+  auto e_out_data = e_out.data_ptr<int64_t>();
+
+  auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
+  auto edge_weight_cdf_data = edge_weight_cdf.data_ptr<float_t>();
+
+  compute_cdf(rowptr_data, edge_weight_data, edge_weight_cdf_data, rowptr.numel());
+
+  rejection_sampling_weighted(rowptr_data, col_data, edge_weight_cdf_data,
+                              start_data, n_out_data, e_out_data, start.numel(),
+                              walk_length, p, q);
+
+  return std::make_tuple(n_out, e_out);
+}
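For intuition, here is a minimal Python sketch of the same sampling technique (per-row CDF plus lower-bound binary search), using `torch.searchsorted` in place of the hand-rolled loop. The helper names are illustrative only and not part of this patch; it assumes the weights of every row already sum to one:

    import torch

    def compute_cdf_py(rowptr, edge_weight):
        # Inclusive per-row cumulative sum, mirroring `compute_cdf()` above.
        cdf = torch.empty_like(edge_weight)
        for v in range(rowptr.numel() - 1):
            row = slice(int(rowptr[v]), int(rowptr[v + 1]))
            cdf[row] = torch.cumsum(edge_weight[row], dim=0)
        return cdf

    def sample_neighbor(rowptr, col, cdf, v):
        # Mirrors `get_offset()`: draw u ~ U[0, 1) and locate the first CDF
        # entry >= u (lower bound), clamped to the row's last edge.
        row_start, row_end = int(rowptr[v]), int(rowptr[v + 1])
        u = torch.rand(1)
        offset = int(torch.searchsorted(cdf[row_start:row_end], u)[0])
        offset = min(offset, row_end - row_start - 1)
        return int(col[row_start + offset])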
diff --git a/csrc/cpu/rw_cpu.h b/csrc/cpu/rw_cpu.h
index 6ea9fcd..8d0c108 100644
--- a/csrc/cpu/rw_cpu.h
+++ b/csrc/cpu/rw_cpu.h
@@ -5,3 +5,8 @@
 std::tuple<torch::Tensor, torch::Tensor>
 random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
                 int64_t walk_length, double p, double q);
+
+std::tuple<torch::Tensor, torch::Tensor>
+random_walk_weighted_cpu(torch::Tensor rowptr, torch::Tensor col,
+                         torch::Tensor edge_weight, torch::Tensor start,
+                         int64_t walk_length, double p, double q);
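The CPU function above and the CUDA kernel below drive the node2vec rejection step with the same three acceptance thresholds (`prob_0`, `prob_1`, `prob_2`). An illustrative Python rendering of that step; the helper and its arguments are ours, not part of the patch:

    def accept(x, t, neighbors_of_t, r, p, q):
        # x: candidate sampled from v's CDF, t: previous node,
        # r: uniform sample in [0, 1).
        max_prob = max(1. / p, 1., 1. / q)
        if x == t:                        # d(t, x) = 0: revisit, weight 1/p
            return r < (1. / p) / max_prob
        if x in neighbors_of_t:           # d(t, x) = 1: weight 1
            return r < 1. / max_prob
        return r < (1. / q) / max_prob    # d(t, x) = 2: weight 1/q

On rejection, a new candidate `x` and a new `r` are drawn, which is exactly the `while (true)` loop in both kernels.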
diff --git a/csrc/cuda/rw_cuda.cu b/csrc/cuda/rw_cuda.cu
index 763b861..bffdc34 100644
--- a/csrc/cuda/rw_cuda.cu
+++ b/csrc/cuda/rw_cuda.cu
@@ -150,3 +150,163 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
   return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
 }
+
+
+__global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight,
+                           float_t *edge_weight_cdf, int64_t numel) {
+  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (thread_idx < numel - 1) {
+    int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1];
+
+    float_t acc = 0.0;
+
+    for (int64_t i = row_start; i < row_end; i++) {
+      acc += edge_weight[i];
+      edge_weight_cdf[i] = acc;
+    }
+  }
+}
+
+__device__ void get_offset(const float_t *edge_weight, int64_t start, int64_t end,
+                           float_t value, int64_t *position_out) {
+  int64_t original_start = start, original_end = end;
+
+  while (start < end) {
+    const int64_t mid = start + ((end - start) >> 1);
+    const float_t mid_val = edge_weight[mid];
+    if (!(mid_val >= value)) {
+      start = mid + 1;
+    } else {
+      end = mid;
+    }
+  }
+
+  if (start == original_end) {
+    // Guard against floating-point round-off: `value` can lie above the
+    // last CDF entry of the row; clamp to the last edge in that case.
+    start = original_end - 1;
+  }
+
+  *position_out = start - original_start;
+}
+
+__global__ void
+rejection_sampling_weighted_kernel(unsigned int seed, const int64_t *rowptr,
+                                   const int64_t *col, const float_t *edge_weight_cdf,
+                                   const int64_t *start, int64_t *n_out,
+                                   int64_t *e_out, const int64_t walk_length,
+                                   const int64_t numel, const double p,
+                                   const double q) {
+
+  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  // Give every thread its own curand subsequence so that walks sampled in
+  // parallel do not share the same random stream.
+  curandState_t state;
+  curand_init(seed, thread_idx, 0, &state);
+
+  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
+  double prob_0 = 1. / p / max_prob;
+  double prob_1 = 1. / max_prob;
+  double prob_2 = 1. / q / max_prob;
+
+  if (thread_idx < numel) {
+    int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end, offset;
+
+    n_out[thread_idx] = t;
+
+    row_start = rowptr[t], row_end = rowptr[t + 1];
+
+    if (row_end - row_start == 0) {
+      e_cur = -1;
+      v = t;
+    } else {
+      get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
+      e_cur = row_start + offset;
+      v = col[e_cur];
+    }
+
+    n_out[numel + thread_idx] = v;
+    e_out[thread_idx] = e_cur;
+
+    for (int64_t l = 1; l < walk_length; l++) {
+      row_start = rowptr[v], row_end = rowptr[v + 1];
+
+      if (row_end - row_start == 0) {
+        e_cur = -1;
+        x = v;
+      } else if (row_end - row_start == 1) {
+        e_cur = row_start;
+        x = col[e_cur];
+      } else {
+        if (p == 1 && q == 1) {
+          get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
+          e_cur = row_start + offset;
+          x = col[e_cur];
+        } else {
+          while (true) {
+            get_offset(edge_weight_cdf, row_start, row_end, curand_uniform(&state), &offset);
+            e_cur = row_start + offset;
+            x = col[e_cur];
+
+            double r = curand_uniform(&state); // (0, 1]
+
+            if (x == t && r < prob_0)
+              break;
+
+            // Check whether x is a neighbor of t. Use separate bounds here
+            // so that `row_start`/`row_end` of v stay intact for the next
+            // iteration of the rejection loop.
+            bool is_neighbor = false;
+            int64_t x_row_start = rowptr[x], x_row_end = rowptr[x + 1];
+            for (int64_t i = x_row_start; i < x_row_end; i++) {
+              if (col[i] == t) {
+                is_neighbor = true;
+                break;
+              }
+            }
+
+            if (is_neighbor && r < prob_1)
+              break;
+            else if (r < prob_2)
+              break;
+          }
+        }
+      }
+
+      n_out[(l + 1) * numel + thread_idx] = x;
+      e_out[l * numel + thread_idx] = e_cur;
+      t = v;
+      v = x;
+    }
+  }
+}
+
+
+std::tuple<torch::Tensor, torch::Tensor>
+random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
+                          torch::Tensor edge_weight, torch::Tensor start,
+                          int64_t walk_length, double p, double q) {
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
+  CHECK_CUDA(edge_weight);
+  CHECK_CUDA(start);
+  cudaSetDevice(rowptr.get_device());
+
+  CHECK_INPUT(rowptr.dim() == 1);
+  CHECK_INPUT(col.dim() == 1);
+  CHECK_INPUT(edge_weight.dim() == 1);
+  CHECK_INPUT(start.dim() == 1);
+
+  auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
+  auto e_out = torch::empty({walk_length, start.size(0)}, start.options());
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  auto edge_weight_cdf = torch::empty({edge_weight.size(0)}, edge_weight.options());
+
+  cdf_kernel<<<BLOCKS(rowptr.numel()), THREADS, 0, stream>>>(
+      rowptr.data_ptr<int64_t>(), edge_weight.data_ptr<float_t>(),
+      edge_weight_cdf.data_ptr<float_t>(), rowptr.numel());
+
+  rejection_sampling_weighted_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
+      time(NULL), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
+      edge_weight_cdf.data_ptr<float_t>(), start.data_ptr<int64_t>(),
+      n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(),
+      walk_length, start.numel(), p, q);
+
+  return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
+}
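Note that, unlike the CPU path, the kernels write walks in a transposed `[walk_length + 1, num_walks]` layout (`n_out[(l + 1) * numel + thread_idx]`), so that at each step consecutive threads write to consecutive addresses; the host function restores the row-major layout with `.t().contiguous()`. A quick sketch of that index mapping (illustrative values):

    import torch

    num_walks, walk_length = 4, 3
    # Kernel-side layout: step l of walk n is stored at flat[l * num_walks + n].
    flat = torch.arange((walk_length + 1) * num_walks)
    walks = flat.view(walk_length + 1, num_walks).t().contiguous()
    # After the transpose, walks[n, l] recovers the kernel's entry:
    assert walks[2, 1] == flat[1 * num_walks + 2]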
diff --git a/csrc/cuda/rw_cuda.h b/csrc/cuda/rw_cuda.h
index 79f4139..b3919d8 100644
--- a/csrc/cuda/rw_cuda.h
+++ b/csrc/cuda/rw_cuda.h
@@ -5,3 +5,8 @@
 std::tuple<torch::Tensor, torch::Tensor>
 random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
                  int64_t walk_length, double p, double q);
+
+std::tuple<torch::Tensor, torch::Tensor>
+random_walk_weighted_cuda(torch::Tensor rowptr, torch::Tensor col,
+                          torch::Tensor edge_weight, torch::Tensor start,
+                          int64_t walk_length, double p, double q);
diff --git a/csrc/rw.cpp b/csrc/rw.cpp
index 0f8de62..5d48646 100644
--- a/csrc/rw.cpp
+++ b/csrc/rw.cpp
@@ -33,5 +33,21 @@ random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
   }
 }
 
+CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
+random_walk_weighted(torch::Tensor rowptr, torch::Tensor col,
+                     torch::Tensor edge_weight, torch::Tensor start,
+                     int64_t walk_length, double p, double q) {
+  if (rowptr.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return random_walk_weighted_cuda(rowptr, col, edge_weight, start, walk_length, p, q);
+#else
+    AT_ERROR("Not compiled with CUDA support");
+#endif
+  } else {
+    return random_walk_weighted_cpu(rowptr, col, edge_weight, start, walk_length, p, q);
+  }
+}
+
 static auto registry =
-    torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk);
+    torch::RegisterOperators().op("torch_cluster::random_walk", &random_walk)
+        .op("torch_cluster::random_walk_weighted", &random_walk_weighted);
diff --git a/test/test_rw.py b/test/test_rw.py
index 0ff91a4..345d971 100644
--- a/test/test_rw.py
+++ b/test/test_rw.py
@@ -77,3 +77,44 @@ def test_rw_small_with_edge_indices(device):
         [1, 0, 1, 0],
         [-1, -1, -1, -1],
     ]
+
+
+@pytest.mark.parametrize('device', devices)
+def test_rw_large_with_edge_weights(device):
+    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
+    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
+    start = tensor([0, 1, 2, 3, 4], torch.long, device)
+    edge_weight = tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], torch.float, device)
+    walk_length = 10
+
+    out = random_walk(
+        row, col, start, walk_length,
+        edge_weight=edge_weight,
+    )
+    assert out[:, 0].tolist() == start.tolist()
+
+    for n in range(start.size(0)):
+        cur = start[n].item()
+        for i in range(1, walk_length):
+            assert out[n, i].item() in col[row == cur].tolist()
+            cur = out[n, i].item()
+
+
+@pytest.mark.parametrize('device', devices)
+def test_rw_small_with_edge_weights(device):
+    row = tensor([0, 1], torch.long, device)
+    col = tensor([1, 0], torch.long, device)
+    start = tensor([0, 1, 2], torch.long, device)
+    edge_weight = tensor([1, 1], torch.float, device)
+    walk_length = 4
+
+    out = random_walk(
+        row, col, start, walk_length,
+        num_nodes=3,
+        edge_weight=edge_weight,
+    )
+    assert out.tolist() == [
+        [0, 1, 0, 1, 0],
+        [1, 0, 1, 0, 1],
+        [2, 2, 2, 2, 2],
+    ]
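Both tests use uniform weights, so they check reachability but not that the weights actually bias the sampling. A hedged sketch of an additional check, not part of this patch, reusing the `tensor` helper and `devices` fixture already imported in the test file:

    @pytest.mark.parametrize('device', devices)
    def test_rw_small_with_skewed_edge_weights(device):
        # Hypothetical extra test: node 0 has two outgoing edges whose
        # weights differ by three orders of magnitude, so nearly every walk
        # from node 0 should take the heavy edge 0 -> 1 as its first step.
        row = tensor([0, 0, 1, 2], torch.long, device)
        col = tensor([1, 2, 0, 0], torch.long, device)
        start = tensor([0] * 100, torch.long, device)
        edge_weight = tensor([1000., 1., 1., 1.], torch.float, device)

        out = random_walk(row, col, start, walk_length=1,
                          edge_weight=edge_weight)
        assert (out[:, 1] == 1).float().mean() > 0.9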
diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py
index 12e0683..672a75e 100644
--- a/torch_cluster/rw.py
+++ b/torch_cluster/rw.py
@@ -15,6 +15,7 @@ def random_walk(
     coalesced: bool = True,
     num_nodes: Optional[int] = None,
     return_edge_indices: bool = False,
+    edge_weight: Optional[Tensor] = None,
 ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
     """Samples random walks of length :obj:`walk_length` from all node
     indices in :obj:`start` in the graph given by :obj:`(row, col)` as
     described in the
@@ -39,6 +40,8 @@ def random_walk(
         return_edge_indices (bool, optional): Whether to additionally return
             the indices of edges traversed during the random walk.
             (default: :obj:`False`)
+        edge_weight (Tensor, optional): The weights of the edges given by
+            :obj:`row` and :obj:`col`. (default: :obj:`None`)
 
     :rtype: :class:`LongTensor`
     """
@@ -54,9 +57,17 @@ def random_walk(
     rowptr = row.new_zeros(num_nodes + 1)
     torch.cumsum(deg, 0, out=rowptr[1:])
 
-    node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
-        rowptr, col, start, walk_length, p, q,
-    )
+    if edge_weight is None:
+        node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
+            rowptr, col, start, walk_length, p, q,
+        )
+    else:
+        # Normalize the edge weights so that the outgoing weights of every
+        # node sum to one, as required by the per-row CDF sampling in the
+        # C++/CUDA kernels:
+        row_sum = edge_weight.new_zeros(num_nodes)
+        row_sum.scatter_add_(0, row, edge_weight)
+        edge_weight = edge_weight / row_sum[row]
+
+        node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted(
+            rowptr, col, edge_weight, start, walk_length, p, q,
+        )
 
     if return_edge_indices:
         return node_seq, edge_seq
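End-to-end, the new keyword argument is used like this (a minimal sketch; the graph and weights are arbitrary):

    import torch
    from torch_cluster import random_walk

    row = torch.tensor([0, 1, 1, 2])
    col = torch.tensor([1, 0, 2, 1])
    edge_weight = torch.tensor([1.0, 0.5, 2.0, 1.0])
    start = torch.tensor([0, 1, 2])

    # Walks are biased towards heavy edges; shape: [3, walk_length + 1].
    walks = random_walk(row, col, start, walk_length=4,
                        edge_weight=edge_weight)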