Skip to content

Commit

Permalink
Allow returning edge indices from random walk (#139)
Browse files Browse the repository at this point in the history
This commit adds an optional argument in the `random_walk` function,
namely `return_edge_indices`. The default behaviour is not changed, but
if a user wants to directly use the edges visited by the random walker,
we can return the indices of those edges by setting
`return_edge_indices` to `True`. New cases are also added to the test
suite.
  • Loading branch information
pbielak authored Aug 8, 2022
1 parent c77ed13 commit cc4696b
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
51 changes: 50 additions & 1 deletion test/test_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@pytest.mark.parametrize('device', devices)
def test_rw(device):
def test_rw_large(device):
row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
start = tensor([0, 1, 2, 3, 4], torch.long, device)
Expand All @@ -21,10 +21,59 @@ def test_rw(device):
assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item()


@pytest.mark.parametrize('device', devices)
def test_rw_small(device):
    """Check the exact walks produced on a tiny graph with an isolated node.

    The graph has the two directed edges 0 -> 1 and 1 -> 0, so every node has
    at most one outgoing edge and the expected walks are fully determined.
    Node 2 has no outgoing edge; the expected output pins that its walk
    repeats the start node for every step.
    """
    src = tensor([0, 1], torch.long, device)
    dst = tensor([1, 0], torch.long, device)
    seeds = tensor([0, 1, 2], torch.long, device)
    steps = 4

    # One row per start node, `steps + 1` entries each (start node included).
    expected_walks = [
        [0, 1, 0, 1, 0],
        [1, 0, 1, 0, 1],
        [2, 2, 2, 2, 2],
    ]

    walks = random_walk(src, dst, seeds, steps, num_nodes=3)
    assert walks.tolist() == expected_walks


@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
    """Check random walks with edge indices on a graph without dead ends.

    Every node in this graph has at least one outgoing edge, so each step of
    each walk must follow an existing edge and no edge index may be the
    ``-1`` placeholder.
    """
    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
    walk_length = 10

    node_seq, edge_seq = random_walk(
        row, col, start, walk_length,
        return_edge_indices=True,
    )

    # Column 0 of the node sequence holds the start nodes themselves.
    assert node_seq[:, 0].tolist() == start.tolist()

    for n in range(start.size(0)):
        cur = start[n].item()
        # The walk has `walk_length + 1` columns (start node plus one node per
        # step), so the range must run up to and including `walk_length` to
        # also validate the final transition.
        for i in range(1, walk_length + 1):
            assert node_seq[n, i].item() in col[row == cur].tolist()
            cur = node_seq[n, i].item()

    # No node is isolated, so every step must have traversed a real edge.
    assert (edge_seq != -1).all()


@pytest.mark.parametrize('device', devices)
def test_rw_small_with_edge_indices(device):
    """Check exact node and edge sequences on a tiny graph.

    The graph has the two directed edges 0 -> 1 (edge 0) and 1 -> 0 (edge 1),
    so both returned sequences are fully determined. Node 2 has no outgoing
    edge; the expected output pins that its walk repeats the start node and
    that its edge indices are all ``-1``.
    """
    src = tensor([0, 1], torch.long, device)
    dst = tensor([1, 0], torch.long, device)
    seeds = tensor([0, 1, 2], torch.long, device)
    steps = 4

    # Visited nodes: `steps + 1` entries per walk (start node included).
    expected_nodes = [
        [0, 1, 0, 1, 0],
        [1, 0, 1, 0, 1],
        [2, 2, 2, 2, 2],
    ]
    # Traversed edges: one entry per step.
    expected_edges = [
        [0, 1, 0, 1],
        [1, 0, 1, 0],
        [-1, -1, -1, -1],
    ]

    result = random_walk(
        src, dst, seeds, steps,
        num_nodes=3,
        return_edge_indices=True,
    )
    visited, traversed = result

    assert visited.tolist() == expected_nodes
    assert traversed.tolist() == expected_edges
29 changes: 23 additions & 6 deletions torch_cluster/rw.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from typing import Optional
from typing import Optional, Tuple, Union

import torch
from torch import Tensor


@torch.jit.script
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
p: float = 1, q: float = 1, coalesced: bool = True,
num_nodes: Optional[int] = None) -> Tensor:
def random_walk(
row: Tensor,
col: Tensor,
start: Tensor,
walk_length: int,
p: float = 1,
q: float = 1,
coalesced: bool = True,
num_nodes: Optional[int] = None,
return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Samples random walks of length :obj:`walk_length` from all node indices
in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
`"node2vec: Scalable Feature Learning for Networks"
Expand All @@ -28,6 +36,9 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
the graph given by :obj:`(row, col)` according to :obj:`row`.
(default: :obj:`True`)
num_nodes (int, optional): The number of nodes. (default: :obj:`None`)
return_edge_indices (bool, optional): Whether to additionally return
the indices of edges traversed during the random walk.
(default: :obj:`False`)
:rtype: :class:`LongTensor`
"""
Expand All @@ -43,5 +54,11 @@ def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
rowptr = row.new_zeros(num_nodes + 1)
torch.cumsum(deg, 0, out=rowptr[1:])

return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
p, q)[0]
node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q,
)

if return_edge_indices:
return node_seq, edge_seq

return node_seq

0 comments on commit cc4696b

Please sign in to comment.