Skip to content

Commit

Permalink
Edge-based temporal sampling (pyg-team#280)
Browse files Browse the repository at this point in the history
This PR is to enable edge-based temporal sampling for both homogeneous
and heterogeneous graphs.
Thanks,
Poovaiah

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
3 people committed Nov 14, 2023
1 parent b585446 commit 2b9af1c
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 135 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.4.0] - 2023-MM-DD
### Added
- Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
### Changed
- Dropped the MKL code path when sampling neighbors with `replace=False` since it does not correctly prevent duplicates ([#275](https://github.com/pyg-team/pyg-lib/pull/275))
Expand Down
1 change: 1 addition & 0 deletions benchmark/sampler/hetero_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_hetero_neighbor(dataset, **kwargs):
seed_dict,
num_neighbors_dict,
node_time_dict,
edge_time_dict=None,
seed_time_dict=None,
edge_weight_dict=edge_weight_dict,
csc=True,
Expand Down
3 changes: 2 additions & 1 deletion benchmark/sampler/neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def test_neighbor(dataset, **kwargs):
col,
seed,
num_neighbors,
time=node_time,
node_time=node_time,
edge_time=None,
seed_time=None,
edge_weight=edge_weight,
replace=args.replace,
Expand Down
232 changes: 168 additions & 64 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand All @@ -38,7 +39,8 @@ hetero_neighbor_sample_kernel(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
Expand All @@ -53,7 +55,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand Down
47 changes: 26 additions & 21 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand All @@ -38,9 +39,9 @@ neighbor_sample(const at::Tensor& rowptr,
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
return op.call(rowptr, col, seed, num_neighbors, node_time, edge_time,
seed_time, edge_weight, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand All @@ -56,7 +57,8 @@ hetero_neighbor_sample(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
Expand Down Expand Up @@ -89,17 +91,18 @@ hetero_neighbor_sample(
.findSchemaOrThrow("pyg::hetero_neighbor_sample", "")
.typed<decltype(hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, seed_time_dict,
edge_weight_dict, csc, replace, directed, disjoint,
temporal_strategy, return_edge_id);
num_neighbors_dict, node_time_dict, edge_time_dict,
seed_time_dict, edge_weight_dict, csc, replace, directed,
disjoint, temporal_strategy, return_edge_id);
}

std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
Expand All @@ -118,34 +121,36 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::dist_neighbor_sample", "")
.typed<decltype(dist_neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy);
return op.call(rowptr, col, seed, num_neighbors, node_time, edge_time,
seed_time, edge_weight, csc, replace, directed, disjoint,
temporal_strategy);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) -> "
"num_neighbors, Tensor? node_time = None, Tensor? edge_time = None, "
"Tensor? seed_time = None, Tensor? edge_weight = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Tensor, Tensor, Tensor, Tensor?, int[], int[])"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict "
"= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, "
"Dict(str, Tensor)? node_time_dict = None, Dict(str, Tensor)? "
"edge_time_dict = None, Dict(str, Tensor)? seed_time_dict = None, "
"Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform') -> (Tensor, Tensor, int[])"));
"num_neighbors, Tensor? node_time = None, Tensor? edge_time = None, "
"Tensor? seed_time = None, Tensor? edge_weight = None, bool csc = False, "
"bool replace = False, bool directed = True, bool disjoint = False, "
"str temporal_strategy = 'uniform') -> (Tensor, Tensor, int[])"));
}

} // namespace sampler
Expand Down
10 changes: 7 additions & 3 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& node_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
bool csc = false,
Expand All @@ -48,7 +49,9 @@ hetero_neighbor_sample(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict =
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict =
c10::nullopt,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict =
c10::nullopt,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict =
c10::nullopt,
Expand All @@ -72,7 +75,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<int64_t>> dist_neighbor_sample(
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& node_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
bool csc = false,
Expand Down
59 changes: 31 additions & 28 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def neighbor_sample(
col: Tensor,
seed: Tensor,
num_neighbors: List[int],
time: Optional[Tensor] = None,
node_time: Optional[Tensor] = None,
edge_time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
csc: bool = False,
Expand All @@ -39,16 +40,27 @@ def neighbor_sample(
num_neighbors (List[int]): The number of neighbors to sample for each
node in each iteration. If an entry is set to :obj:`-1`, all
neighbors will be included.
time (torch.Tensor, optional): Timestamps for the nodes in the graph.
If set, temporal sampling will be used such that neighbors are
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the seed node.
node_time (torch.Tensor, optional): Timestamps for the nodes in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
nodes have an earlier or equal timestamp than the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
edge_time (torch.Tensor, optional): Timestamps for the edges in the
graph. If set, temporal sampling will be used such that neighbors
are guaranteed to fulfill temporal constraints, *i.e.* sampled
edges have an earlier or equal timestamp than the seed node.
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
Only either :obj:`node_time` or :obj:`edge_time` can be specified.
(default: :obj:`None`)
seed_time (torch.Tensor, optional): Optional values to override the
timestamp for seed nodes. If not set, will use timestamps in
:obj:`time` as default for seed nodes. (default: :obj:`None`)
:obj:`node_time` as default for seed nodes.
Needs to be specified in case edge-level sampling is used via
:obj:`edge_time`. (default: :obj:`None`)
edge-weight (torch.Tensor, optional): If given, will perform biased
sampling based on the weight of each edge. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
Expand All @@ -75,18 +87,19 @@ def neighbor_sample(
Lastly, returns information about the sampled amount of nodes and edges
per hop.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, seed_time, edge_weight, csc,
replace, directed, disjoint,
temporal_strategy, return_edge_id)
return torch.ops.pyg.neighbor_sample( #
rowptr, col, seed, num_neighbors, node_time, edge_time, seed_time,
edge_weight, csc, replace, directed, disjoint, temporal_strategy,
return_edge_id)


def hetero_neighbor_sample(
rowptr_dict: Dict[EdgeType, Tensor],
col_dict: Dict[EdgeType, Tensor],
seed_dict: Dict[NodeType, Tensor],
num_neighbors_dict: Dict[EdgeType, List[int]],
time_dict: Optional[Dict[NodeType, Tensor]] = None,
node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,
seed_time_dict: Optional[Dict[NodeType, Tensor]] = None,
edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None,
csc: bool = False,
Expand Down Expand Up @@ -123,29 +136,19 @@ def hetero_neighbor_sample(
TO_REL_TYPE[k]: v
for k, v in num_neighbors_dict.items()
}
if edge_time_dict is not None:
edge_time_dict = {TO_REL_TYPE[k]: v for k, v in edge_time_dict.items()}
if edge_weight_dict is not None:
edge_weight_dict = {
TO_REL_TYPE[k]: v
for k, v in edge_weight_dict.items()
}

out = torch.ops.pyg.hetero_neighbor_sample(
node_types,
edge_types,
rowptr_dict,
col_dict,
seed_dict,
num_neighbors_dict,
time_dict,
seed_time_dict,
edge_weight_dict,
csc,
replace,
directed,
disjoint,
temporal_strategy,
return_edge_id,
)
out = torch.ops.pyg.hetero_neighbor_sample( #
node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, node_time_dict, edge_time_dict, seed_time_dict,
edge_weight_dict, csc, replace, directed, disjoint, temporal_strategy,
return_edge_id)

(row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict,
num_edges_per_hop_dict) = out
Expand Down
9 changes: 6 additions & 3 deletions test/csrc/sampler/test_dist_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ TEST(WithReplacementNeighborTest, BasicAssertions) {
/*col=*/std::get<1>(graph),
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/2,
/*time=*/c10::nullopt,
/*node_time=*/c10::nullopt,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand All @@ -85,7 +86,8 @@ TEST(DistDisjointNeighborTest, BasicAssertions) {
/*col=*/std::get<1>(graph),
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/2,
/*time=*/c10::nullopt,
/*node_time=*/c10::nullopt,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand Down Expand Up @@ -121,7 +123,8 @@ TEST(DistTemporalNeighborTest, BasicAssertions) {
/*col=*/col,
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/2,
/*time=*/time,
/*node_time=*/time,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand Down
9 changes: 6 additions & 3 deletions test/csrc/sampler/test_dist_relabel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ TEST(DistDisjointRelabelNeighborhoodTest, BasicAssertions) {
/*col=*/std::get<1>(graph),
/*seed=*/seed,
/*num_neighbors=*/{2},
/*time=*/c10::nullopt,
/*node_time=*/c10::nullopt,
/*edge_time=*/c10::nullopt,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc*/ false,
Expand Down Expand Up @@ -182,7 +183,8 @@ TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) {
/*col_dict=*/col_dict,
/*seed_dict=*/seed_dict,
/*num_neighbors_dict=*/num_neighbors_dict,
/*time_dict=*/c10::nullopt,
/*node_time_dict=*/c10::nullopt,
/*edge_time_dict=*/c10::nullopt,
/*seed_time_dict=*/c10::nullopt,
/*edge_weight_dict=*/c10::nullopt,
/*csc=*/true);
Expand Down Expand Up @@ -246,7 +248,8 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) {
/*col_dict=*/col_dict,
/*seed_dict=*/seed_dict,
/*num_neighbors_dict=*/num_neighbors_dict,
/*time_dict=*/c10::nullopt,
/*node_time_dict=*/c10::nullopt,
/*edge_time_dict=*/c10::nullopt,
/*seed_time_dict=*/c10::nullopt,
/*edge_weight_dict=*/c10::nullopt,
/*csc=*/false,
Expand Down
Loading

0 comments on commit 2b9af1c

Please sign in to comment.