From cc4696bddedfb2a3479b9acf9cde562c2c1f237b Mon Sep 17 00:00:00 2001 From: Piotr Bielak Date: Mon, 8 Aug 2022 16:02:38 +0200 Subject: [PATCH] Allow returning edge indices from random walk (#139) This commit adds an optional argument in the `random_walk` function, namely `return_edge_indices`. The default behaviour is not changed, but if a user wants to directly use the edges visited by the random walker, we can return the indices of those edges by setting `return_edge_indices` to `True`. New cases are also added to the test suite. --- test/test_rw.py | 51 ++++++++++++++++++++++++++++++++++++++++++++- torch_cluster/rw.py | 29 ++++++++++++++++++++------ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/test/test_rw.py b/test/test_rw.py index 76350d6b..0ff91a44 100644 --- a/test/test_rw.py +++ b/test/test_rw.py @@ -6,7 +6,7 @@ @pytest.mark.parametrize('device', devices) -def test_rw(device): +def test_rw_large(device): row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) start = tensor([0, 1, 2, 3, 4], torch.long, device) @@ -21,6 +21,9 @@ def test_rw(device): assert out[n, i].item() in col[row == cur].tolist() cur = out[n, i].item() + +@pytest.mark.parametrize('device', devices) +def test_rw_small(device): row = tensor([0, 1], torch.long, device) col = tensor([1, 0], torch.long, device) start = tensor([0, 1, 2], torch.long, device) @@ -28,3 +31,49 @@ def test_rw(device): out = random_walk(row, col, start, walk_length, num_nodes=3) assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] + + +@pytest.mark.parametrize('device', devices) +def test_rw_large_with_edge_indices(device): + row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device) + col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device) + start = tensor([0, 1, 2, 3, 4], torch.long, device) + walk_length = 10 + + node_seq, edge_seq = random_walk( + row, col, start, walk_length, + return_edge_indices=True, + ) + assert node_seq[:, 0].tolist() == start.tolist() + + for n in range(start.size(0)): + cur = start[n].item() + for i in range(1, walk_length): + assert node_seq[n, i].item() in col[row == cur].tolist() + cur = node_seq[n, i].item() + + assert (edge_seq != -1).all() + + +@pytest.mark.parametrize('device', devices) +def test_rw_small_with_edge_indices(device): + row = tensor([0, 1], torch.long, device) + col = tensor([1, 0], torch.long, device) + start = tensor([0, 1, 2], torch.long, device) + walk_length = 4 + + node_seq, edge_seq = random_walk( + row, col, start, walk_length, + num_nodes=3, + return_edge_indices=True, + ) + assert node_seq.tolist() == [ + [0, 1, 0, 1, 0], + [1, 0, 1, 0, 1], + [2, 2, 2, 2, 2], + ] + assert edge_seq.tolist() == [ + [0, 1, 0, 1], + [1, 0, 1, 0], + [-1, -1, -1, -1], + ] diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py index adb6d6da..12e06837 100644 --- a/torch_cluster/rw.py +++ b/torch_cluster/rw.py @@ -1,13 +1,21 @@ -from typing import Optional +from typing import Optional, Tuple, Union import torch from torch import Tensor @torch.jit.script -def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, - p: float = 1, q: float = 1, coalesced: bool = True, - num_nodes: Optional[int] = None) -> Tensor: +def random_walk( + row: Tensor, + col: Tensor, + start: Tensor, + walk_length: int, + p: float = 1, + q: float = 1, + coalesced: bool = True, + num_nodes: Optional[int] = None, + return_edge_indices: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Samples random walks of length :obj:`walk_length` from all node indices in :obj:`start` in the graph given by :obj:`(row, col)` as described in the `"node2vec: Scalable Feature Learning for Networks" @@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, the graph given by :obj:`(row, col)` according to :obj:`row`. (default: :obj:`True`) num_nodes (int, optional): The number of nodes. (default: :obj:`None`) + return_edge_indices (bool, optional): Whether to additionally return + the indices of edges traversed during the random walk. + (default: :obj:`False`) :rtype: :class:`LongTensor` """ @@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int, rowptr = row.new_zeros(num_nodes + 1) torch.cumsum(deg, 0, out=rowptr[1:]) - return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length, - p, q)[0] + node_seq, edge_seq = torch.ops.torch_cluster.random_walk( + rowptr, col, start, walk_length, p, q, + ) + + if return_edge_indices: + return node_seq, edge_seq + + return node_seq