diff --git a/CHANGELOG.md b/CHANGELOG.md index 85469f4d0..b7404fe5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.4.0] - 2023-MM-DD ### Added +- Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280)) - Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272)) ### Changed - Dropped the MKL code path when sampling neighbors with `replace=False` since it does not correctly prevent duplicates ([#275](https://github.com/pyg-team/pyg-lib/pull/275)) diff --git a/benchmark/sampler/hetero_neighbor.py b/benchmark/sampler/hetero_neighbor.py index 26d3972e3..dbcdcb785 100644 --- a/benchmark/sampler/hetero_neighbor.py +++ b/benchmark/sampler/hetero_neighbor.py @@ -94,6 +94,7 @@ def test_hetero_neighbor(dataset, **kwargs): seed_dict, num_neighbors_dict, node_time_dict, + edge_time_dict=None, seed_time_dict=None, edge_weight_dict=edge_weight_dict, csc=True, diff --git a/benchmark/sampler/neighbor.py b/benchmark/sampler/neighbor.py index 7c6bd9a42..c9c2ed851 100644 --- a/benchmark/sampler/neighbor.py +++ b/benchmark/sampler/neighbor.py @@ -86,7 +86,8 @@ def test_neighbor(dataset, **kwargs): col, seed, num_neighbors, - time=node_time, + node_time=node_time, + edge_time=None, seed_time=None, edge_weight=edge_weight, replace=args.replace, diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 307f9fe82..1068359bc 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -72,14 +72,14 @@ class NeighborSampler { dst_mapper, generator, out_global_dst_nodes); } - void temporal_sample(const node_t global_src_node, - const scalar_t local_src_node, - const int64_t count, - const temporal_t seed_time, - const temporal_t* time, - pyg::sampler::Mapper& dst_mapper, - pyg::random::RandintEngine& generator, - std::vector& out_global_dst_nodes) { + void node_temporal_sample(const node_t global_src_node, + const scalar_t local_src_node, + const int64_t count, + const temporal_t seed_time, + const temporal_t* time, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; @@ -105,6 +105,38 @@ class NeighborSampler { dst_mapper, generator, out_global_dst_nodes); } + void edge_temporal_sample(const node_t global_src_node, + const scalar_t local_src_node, + const int64_t count, + const temporal_t seed_time, + const temporal_t* time, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { + auto row_start = rowptr_[to_scalar_t(global_src_node)]; + auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; + + if ((row_end - row_start == 0) || (count == 0)) + return; + + // Find new `row_end` such that all neighbors fulfill temporal constraints: + auto it = std::upper_bound( + time + row_start, time + row_end, seed_time, + [&](const scalar_t& a, const scalar_t& b) { return a < b; }); + row_end = it - time; + + if (temporal_strategy_ == "last" && count >= 0) { + row_start = std::max(row_start, (scalar_t)(row_end - count)); + } + if (row_end - row_start > 1) { + TORCH_CHECK(time[row_start] <= time[row_end - 1], + "Found invalid non-sorted temporal neighborhood"); + } + + _sample(global_src_node, local_src_node, row_start, row_end, count, + dst_mapper, generator, out_global_dst_nodes); + } + std::tuple> get_sampled_edges(bool csc = false) { TORCH_CHECK(save_edges, "No edges have been stored") @@ -307,26 +339,39 @@ sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const std::vector& num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, const bool csc, const std::string temporal_strategy) { - TORCH_CHECK(!time.has_value() || disjoint, + TORCH_CHECK(!node_time.has_value() || disjoint, "Temporal sampling needs to create disjoint subgraphs"); + TORCH_CHECK(!edge_time.has_value() || disjoint, + "Temporal sampling needs to create disjoint subgraphs"); + TORCH_CHECK(!(node_time.has_value() && edge_time.has_value()), + "Only one of node-level or edge-level sampling is supported "); TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr'"); TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col'"); TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); - if (time.has_value()) { - TORCH_CHECK(time.value().is_contiguous(), "Non-contiguous 'time'"); + if (node_time.has_value()) { + TORCH_CHECK(node_time.value().is_contiguous(), + "Non-contiguous 'node_time'"); + } + if (edge_time.has_value()) { + TORCH_CHECK(edge_time.value().is_contiguous(), + "Non-contiguous 'edge_time'"); + TORCH_CHECK(seed_time.has_value(), "Seed time needs to be specified"); } if (seed_time.has_value()) { TORCH_CHECK(seed_time.value().is_contiguous(), "Non-contiguous 'seed_time'"); } - TORCH_CHECK(!(time.has_value() && edge_weight.has_value()), - "Biased temporal sampling not yet supported"); + TORCH_CHECK(!(node_time.has_value() && edge_weight.has_value()), + "Biased node temporal sampling not yet supported"); + TORCH_CHECK(!(edge_time.has_value() && edge_weight.has_value()), + "Biased edge temporal sampling not yet supported"); at::Tensor out_row, out_col, out_node_id; c10::optional out_edge_id = c10::nullopt; @@ -368,8 +413,8 @@ sample(const at::Tensor& rowptr, for (size_t i = 0; i < seed.numel(); ++i) { seed_times.push_back(seed_time_data[i]); } - } else if (time.has_value()) { - const auto time_data = time.value().data_ptr(); + } else if (node_time.has_value()) { + const auto time_data = node_time.value().data_ptr(); for (size_t i = 0; i < seed.numel(); ++i) { seed_times.push_back(time_data[seed_data[i]]); } @@ -395,7 +440,7 @@ sample(const at::Tensor& rowptr, if constexpr (distributed) cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } - } else if (!time.has_value()) { + } else if (!node_time.has_value() && !edge_time.has_value()) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample( /*global_src_node=*/sampled_nodes[i], @@ -408,19 +453,38 @@ sample(const at::Tensor& rowptr, cumsum_neighbors_per_node.push_back(sampled_nodes.size()); } } else if constexpr (!std::is_scalar::value) { // Temporal: - const auto time_data = time.value().data_ptr(); - for (size_t i = begin; i < end; ++i) { - const auto batch_idx = sampled_nodes[i].first; - sampler.temporal_sample( - /*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, /*count=*/count, - /*seed_time=*/seed_times[batch_idx], - /*time=*/time_data, - /*dst_mapper=*/mapper, - /*generator=*/generator, - /*out_global_dst_nodes=*/sampled_nodes); - if constexpr (distributed) - cumsum_neighbors_per_node.push_back(sampled_nodes.size()); + if (edge_time.has_value()) { + const auto edge_time_data = edge_time.value().data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = sampled_nodes[i].first; + sampler.edge_temporal_sample( + /*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, + /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*time=*/edge_time_data, + /*dst_mapper=*/mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); + } + } else { + const auto node_time_data = node_time.value().data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = sampled_nodes[i].first; + sampler.node_temporal_sample( + /*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, + /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*time=*/node_time_data, + /*dst_mapper=*/mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/sampled_nodes); + if constexpr (distributed) + cumsum_neighbors_per_node.push_back(sampled_nodes.size()); + } } } begin = end, end = sampled_nodes.size(); @@ -462,13 +526,18 @@ sample(const std::vector& node_types, const c10::Dict& col_dict, const c10::Dict& seed_dict, const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, + const c10::optional>& node_time_dict, + const c10::optional>& edge_time_dict, const c10::optional>& seed_time_dict, const c10::optional>& edge_weight_dict, const bool csc, const std::string temporal_strategy) { - TORCH_CHECK(!time_dict.has_value() || disjoint, - "Temporal sampling needs to create disjoint subgraphs"); + TORCH_CHECK(!node_time_dict.has_value() || disjoint, + "Node temporal sampling needs to create disjoint subgraphs"); + TORCH_CHECK(!edge_time_dict.has_value() || disjoint, + "Edge temporal sampling needs to create disjoint subgraphs"); + TORCH_CHECK(!(node_time_dict.has_value() && edge_time_dict.has_value()), + "Only one of node-level or edge-level sampling is supported "); for (const auto& kv : rowptr_dict) { const at::Tensor& rowptr = kv.value(); @@ -482,19 +551,28 @@ sample(const std::vector& node_types, const at::Tensor& seed = kv.value(); TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'"); } - if (time_dict.has_value()) { - for (const auto& kv : time_dict.value()) { + if (node_time_dict.has_value()) { + for (const auto& kv : node_time_dict.value()) { const at::Tensor& time = kv.value(); - TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'time'"); + TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'node_time'"); } } + if (edge_time_dict.has_value()) { + for (const auto& kv : edge_time_dict.value()) { + const at::Tensor& time = kv.value(); + TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'edge_time'"); + } + TORCH_CHECK(seed_time_dict.has_value(), "Seed time needs to be specified"); + } if (seed_time_dict.has_value()) { for (const auto& kv : seed_time_dict.value()) { const at::Tensor& seed_time = kv.value(); TORCH_CHECK(seed_time.is_contiguous(), "Non-contiguous 'seed_time'"); } } - TORCH_CHECK(!(time_dict.has_value() && edge_weight_dict.has_value()), + TORCH_CHECK(!(node_time_dict.has_value() && edge_weight_dict.has_value()), + "Biased temporal sampling not yet supported"); + TORCH_CHECK(!(edge_time_dict.has_value() && edge_weight_dict.has_value()), "Biased temporal sampling not yet supported"); c10::Dict out_row_dict, out_col_dict; @@ -604,8 +682,8 @@ sample(const std::vector& node_types, for (size_t i = 0; i < seed.numel(); ++i) { seed_times.push_back(seed_time_data[i]); } - } else if (time_dict.has_value()) { - const at::Tensor& time = time_dict.value().at(kv.key()); + } else if (node_time_dict.has_value()) { + const at::Tensor& time = node_time_dict.value().at(kv.key()); const auto time_data = time.data_ptr(); seed_times.reserve(seed_times.size() + seed.numel()); for (size_t i = 0; i < seed.numel(); ++i) { @@ -659,8 +737,10 @@ sample(const std::vector& node_types, /*generator=*/generator, /*out_global_dst_nodes=*/dst_sampled_nodes); } - } else if (!time_dict.has_value() || - !time_dict.value().contains(dst)) { + } else if ((!node_time_dict.has_value() || + !node_time_dict.value().contains(dst)) && + (!edge_time_dict.has_value() || + !edge_time_dict.value().contains(to_rel_type(k)))) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample( /*global_src_node=*/src_sampled_nodes[i], @@ -671,20 +751,41 @@ sample(const std::vector& node_types, /*out_global_dst_nodes=*/dst_sampled_nodes); } } else if constexpr (!std::is_scalar::value) { - // Temporal sampling: - const at::Tensor& dst_time = time_dict.value().at(dst); - const auto dst_time_data = dst_time.data_ptr(); - for (size_t i = begin; i < end; ++i) { - const auto batch_idx = src_sampled_nodes[i].first; - sampler.temporal_sample( - /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, - /*count=*/count, - /*seed_time=*/seed_times[batch_idx], - /*time=*/dst_time_data, - /*dst_mapper=*/dst_mapper, - /*generator=*/generator, - /*out_global_dst_nodes=*/dst_sampled_nodes); + if (edge_time_dict.has_value() && + edge_time_dict.value().contains(to_rel_type(k))) { + // Edge-level temporal sampling: + const at::Tensor& edge_time = + edge_time_dict.value().at(to_rel_type(k)); + const auto edge_time_data = + edge_time.data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = src_sampled_nodes[i].first; + sampler.edge_temporal_sample( + /*global_src_node=*/src_sampled_nodes[i], + /*local_src_node=*/i, + /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*time=*/edge_time_data, + /*dst_mapper=*/dst_mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/dst_sampled_nodes); + } + } else { + // Node-level temporal sampling: + const at::Tensor& dst_time = node_time_dict.value().at(dst); + const auto dst_time_data = dst_time.data_ptr(); + for (size_t i = begin; i < end; ++i) { + const auto batch_idx = src_sampled_nodes[i].first; + sampler.node_temporal_sample( + /*global_src_node=*/src_sampled_nodes[i], + /*local_src_node=*/i, + /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*time=*/dst_time_data, + /*dst_mapper=*/dst_mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/dst_sampled_nodes); + } } } } @@ -799,7 +900,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const std::vector& num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, bool csc, @@ -810,8 +912,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr, bool return_edge_id) { const auto out = [&] { DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, - seed, num_neighbors, time, seed_time, edge_weight, csc, - temporal_strategy); + seed, num_neighbors, node_time, edge_time, seed_time, + edge_weight, csc, temporal_strategy); }(); return std::make_tuple(std::get<0>(out), std::get<1>(out), std::get<2>(out), std::get<3>(out), std::get<4>(out), std::get<5>(out)); @@ -830,7 +932,8 @@ hetero_neighbor_sample_kernel( const c10::Dict& col_dict, const c10::Dict& seed_dict, const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, + const c10::optional>& node_time_dict, + const c10::optional>& edge_time_dict, const c10::optional>& seed_time_dict, const c10::optional>& edge_weight_dict, bool csc, @@ -841,8 +944,8 @@ hetero_neighbor_sample_kernel( bool return_edge_id) { DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, - edge_weight_dict, csc, temporal_strategy); + num_neighbors_dict, node_time_dict, edge_time_dict, + seed_time_dict, edge_weight_dict, csc, temporal_strategy); } std::tuple> @@ -850,7 +953,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const int64_t num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, bool csc, @@ -860,8 +964,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, std::string temporal_strategy) { const auto out = [&] { DISPATCH_DIST_SAMPLE(replace, directed, disjoint, rowptr, col, seed, - {num_neighbors}, time, seed_time, edge_weight, csc, - temporal_strategy); + {num_neighbors}, node_time, edge_time, seed_time, + edge_weight, csc, temporal_strategy); }(); return std::make_tuple(std::get<2>(out), std::get<3>(out).value(), std::get<6>(out)); diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 2d1dfa831..d5ae13015 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -15,7 +15,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const std::vector& num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, bool csc, @@ -38,7 +39,8 @@ hetero_neighbor_sample_kernel( const c10::Dict& col_dict, const c10::Dict& seed_dict, const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, + const c10::optional>& node_time_dict, + const c10::optional>& edge_time_dict, const c10::optional>& seed_time_dict, const c10::optional>& edge_weight_dict, bool csc, @@ -53,7 +55,8 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const int64_t num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, bool csc, diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 38c0cc3cb..a3304f2f8 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -18,7 +18,8 @@ neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const std::vector& num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, bool csc, @@ -38,9 +39,9 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, - csc, replace, directed, disjoint, temporal_strategy, - return_edge_id); + return op.call(rowptr, col, seed, num_neighbors, node_time, edge_time, + seed_time, edge_weight, csc, replace, directed, disjoint, + temporal_strategy, return_edge_id); } std::tuple, @@ -56,7 +57,8 @@ hetero_neighbor_sample( const c10::Dict& col_dict, const c10::Dict& seed_dict, const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict, + const c10::optional>& node_time_dict, + const c10::optional>& edge_time_dict, const c10::optional>& seed_time_dict, const c10::optional>& edge_weight_dict, bool csc, @@ -89,9 +91,9 @@ hetero_neighbor_sample( .findSchemaOrThrow("pyg::hetero_neighbor_sample", "") .typed(); return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, - num_neighbors_dict, time_dict, seed_time_dict, - edge_weight_dict, csc, replace, directed, disjoint, - temporal_strategy, return_edge_id); + num_neighbors_dict, node_time_dict, edge_time_dict, + seed_time_dict, edge_weight_dict, csc, replace, directed, + disjoint, temporal_strategy, return_edge_id); } std::tuple> dist_neighbor_sample( @@ -99,7 +101,8 @@ std::tuple> dist_neighbor_sample( const at::Tensor& col, const at::Tensor& seed, const int64_t num_neighbors, - const c10::optional& time, + const c10::optional& node_time, + const c10::optional& edge_time, const c10::optional& seed_time, const c10::optional& edge_weight, bool csc, @@ -118,34 +121,36 @@ std::tuple> dist_neighbor_sample( static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::dist_neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, - csc, replace, directed, disjoint, temporal_strategy); + return op.call(rowptr, col, seed, num_neighbors, node_time, edge_time, + seed_time, edge_weight, csc, replace, directed, disjoint, + temporal_strategy); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " - "edge_weight = None, bool csc = False, bool replace = False, bool " - "directed = True, bool disjoint = False, str temporal_strategy = " - "'uniform', bool return_edge_id = True) -> " + "num_neighbors, Tensor? node_time = None, Tensor? edge_time = None, " + "Tensor? seed_time = None, Tensor? edge_weight = None, bool csc = False, " + "bool replace = False, bool directed = True, bool disjoint = False, " + "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " "(Tensor, Tensor, Tensor, Tensor?, int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " - "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " - "= None, Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, " + "Dict(str, Tensor)? node_time_dict = None, Dict(str, Tensor)? " + "edge_time_dict = None, Dict(str, Tensor)? seed_time_dict = None, " + "Dict(str, Tensor)? edge_weight_dict = None, bool csc = False, " "bool replace = False, bool directed = True, bool disjoint = False, " "str temporal_strategy = 'uniform', bool return_edge_id = True) -> " "(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor)?, Dict(str, int[]), Dict(str, int[]))")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::dist_neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " - "edge_weight = None, bool csc = False, bool replace = False, bool " - "directed = True, bool disjoint = False, str temporal_strategy = " - "'uniform') -> (Tensor, Tensor, int[])")); + "num_neighbors, Tensor? node_time = None, Tensor? edge_time = None, " + "Tensor? seed_time = None, Tensor? edge_weight = None, bool csc = False, " + "bool replace = False, bool directed = True, bool disjoint = False, " + "str temporal_strategy = 'uniform') -> (Tensor, Tensor, int[])")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 66282e38d..6d6ab2c75 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -21,7 +21,8 @@ neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& seed, const std::vector& num_neighbors, - const c10::optional& time = c10::nullopt, + const c10::optional& node_time = c10::nullopt, + const c10::optional& edge_time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, const c10::optional& edge_weight = c10::nullopt, bool csc = false, @@ -48,7 +49,9 @@ hetero_neighbor_sample( const c10::Dict& col_dict, const c10::Dict& seed_dict, const c10::Dict>& num_neighbors_dict, - const c10::optional>& time_dict = + const c10::optional>& node_time_dict = + c10::nullopt, + const c10::optional>& edge_time_dict = c10::nullopt, const c10::optional>& seed_time_dict = c10::nullopt, @@ -72,7 +75,8 @@ std::tuple> dist_neighbor_sample( const at::Tensor& col, const at::Tensor& seed, const int64_t num_neighbors, - const c10::optional& time = c10::nullopt, + const c10::optional& node_time = c10::nullopt, + const c10::optional& edge_time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, const c10::optional& edge_weight = c10::nullopt, bool csc = false, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index dd0cc1a0f..c2ca0c496 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -13,7 +13,8 @@ def neighbor_sample( col: Tensor, seed: Tensor, num_neighbors: List[int], - time: Optional[Tensor] = None, + node_time: Optional[Tensor] = None, + edge_time: Optional[Tensor] = None, seed_time: Optional[Tensor] = None, edge_weight: Optional[Tensor] = None, csc: bool = False, @@ -39,16 +40,27 @@ def neighbor_sample( num_neighbors (List[int]): The number of neighbors to sample for each node in each iteration. If an entry is set to :obj:`-1`, all neighbors will be included. - time (torch.Tensor, optional): Timestamps for the nodes in the graph. - If set, temporal sampling will be used such that neighbors are - guaranteed to fulfill temporal constraints, *i.e.* neighbors have - an earlier or equal timestamp than the seed node. + node_time (torch.Tensor, optional): Timestamps for the nodes in the + graph. If set, temporal sampling will be used such that neighbors + are guaranteed to fulfill temporal constraints, *i.e.* sampled + nodes have an earlier or equal timestamp than the seed node. If used, the :obj:`col` vector needs to be sorted according to time within individual neighborhoods. Requires :obj:`disjoint=True`. + Only either :obj:`node_time` or :obj:`edge_time` can be specified. + (default: :obj:`None`) + edge_time (torch.Tensor, optional): Timestamps for the edges in the + graph. If set, temporal sampling will be used such that neighbors + are guaranteed to fulfill temporal constraints, *i.e.* sampled + edges have an earlier or equal timestamp than the seed node. + If used, the :obj:`col` vector needs to be sorted according to time + within individual neighborhoods. Requires :obj:`disjoint=True`. + Only either :obj:`node_time` or :obj:`edge_time` can be specified. (default: :obj:`None`) seed_time (torch.Tensor, optional): Optional values to override the timestamp for seed nodes. If not set, will use timestamps in - :obj:`time` as default for seed nodes. (default: :obj:`None`) + :obj:`node_time` as default for seed nodes. + Needs to be specified in case edge-level sampling is used via + :obj:`edge_time`. (default: :obj:`None`) edge-weight (torch.Tensor, optional): If given, will perform biased sampling based on the weight of each edge. (default: :obj:`None`) csc (bool, optional): If set to :obj:`True`, assumes that the graph is @@ -75,10 +87,10 @@ def neighbor_sample( Lastly, returns information about the sampled amount of nodes and edges per hop. """ - return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, - time, seed_time, edge_weight, csc, - replace, directed, disjoint, - temporal_strategy, return_edge_id) + return torch.ops.pyg.neighbor_sample( # + rowptr, col, seed, num_neighbors, node_time, edge_time, seed_time, + edge_weight, csc, replace, directed, disjoint, temporal_strategy, + return_edge_id) def hetero_neighbor_sample( @@ -86,7 +98,8 @@ def hetero_neighbor_sample( col_dict: Dict[EdgeType, Tensor], seed_dict: Dict[NodeType, Tensor], num_neighbors_dict: Dict[EdgeType, List[int]], - time_dict: Optional[Dict[NodeType, Tensor]] = None, + node_time_dict: Optional[Dict[NodeType, Tensor]] = None, + edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None, seed_time_dict: Optional[Dict[NodeType, Tensor]] = None, edge_weight_dict: Optional[Dict[EdgeType, Tensor]] = None, csc: bool = False, @@ -123,29 +136,19 @@ def hetero_neighbor_sample( TO_REL_TYPE[k]: v for k, v in num_neighbors_dict.items() } + if edge_time_dict is not None: + edge_time_dict = {TO_REL_TYPE[k]: v for k, v in edge_time_dict.items()} if edge_weight_dict is not None: edge_weight_dict = { TO_REL_TYPE[k]: v for k, v in edge_weight_dict.items() } - out = torch.ops.pyg.hetero_neighbor_sample( - node_types, - edge_types, - rowptr_dict, - col_dict, - seed_dict, - num_neighbors_dict, - time_dict, - seed_time_dict, - edge_weight_dict, - csc, - replace, - directed, - disjoint, - temporal_strategy, - return_edge_id, - ) + out = torch.ops.pyg.hetero_neighbor_sample( # + node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, node_time_dict, edge_time_dict, seed_time_dict, + edge_weight_dict, csc, replace, directed, disjoint, temporal_strategy, + return_edge_id) (row_dict, col_dict, node_id_dict, edge_id_dict, num_nodes_per_hop_dict, num_edges_per_hop_dict) = out diff --git a/test/csrc/sampler/test_dist_neighbor.cpp b/test/csrc/sampler/test_dist_neighbor.cpp index f13d1c668..ee93b5897 100644 --- a/test/csrc/sampler/test_dist_neighbor.cpp +++ b/test/csrc/sampler/test_dist_neighbor.cpp @@ -59,7 +59,8 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/2, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc*/ false, @@ -85,7 +86,8 @@ TEST(DistDisjointNeighborTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/2, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc*/ false, @@ -121,7 +123,8 @@ TEST(DistTemporalNeighborTest, BasicAssertions) { /*col=*/col, /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/2, - /*time=*/time, + /*node_time=*/time, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc*/ false, diff --git a/test/csrc/sampler/test_dist_relabel.cpp b/test/csrc/sampler/test_dist_relabel.cpp index 09fefd52f..6c8389363 100644 --- a/test/csrc/sampler/test_dist_relabel.cpp +++ b/test/csrc/sampler/test_dist_relabel.cpp @@ -65,7 +65,8 @@ TEST(DistDisjointRelabelNeighborhoodTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/seed, /*num_neighbors=*/{2}, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc*/ false, @@ -182,7 +183,8 @@ TEST(DistHeteroRelabelNeighborhoodCscTest, BasicAssertions) { /*col_dict=*/col_dict, /*seed_dict=*/seed_dict, /*num_neighbors_dict=*/num_neighbors_dict, - /*time_dict=*/c10::nullopt, + /*node_time_dict=*/c10::nullopt, + /*edge_time_dict=*/c10::nullopt, /*seed_time_dict=*/c10::nullopt, /*edge_weight_dict=*/c10::nullopt, /*csc=*/true); @@ -246,7 +248,8 @@ TEST(DistHeteroDisjointRelabelNeighborhoodTest, BasicAssertions) { /*col_dict=*/col_dict, /*seed_dict=*/seed_dict, /*num_neighbors_dict=*/num_neighbors_dict, - /*time_dict=*/c10::nullopt, + /*node_time_dict=*/c10::nullopt, + /*edge_time_dict=*/c10::nullopt, /*seed_time_dict=*/c10::nullopt, /*edge_weight_dict=*/c10::nullopt, /*csc=*/false, diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 20c94d18b..7f501641c 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -41,7 +41,8 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/at ::arange(2, 4, options), /*num_neighbors=*/{1, 1}, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc=*/false, @@ -68,7 +69,8 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/{1, 1}, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc=*/false, @@ -94,7 +96,8 @@ TEST(DisjointNeighborTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/{2, 2}, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc=*/false, @@ -114,7 +117,7 @@ TEST(DisjointNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); } -TEST(TemporalNeighborTest, BasicAssertions) { +TEST(NodeLevelTemporalNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); @@ -122,7 +125,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { auto col = std::get<1>(graph); // Time is equal to node ID ... - auto time = at::arange(6, options); + auto node_time = at::arange(6, options); // ... so we need to sort the column vector by time/node ID: col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); @@ -131,7 +134,8 @@ TEST(TemporalNeighborTest, BasicAssertions) { /*col=*/col, /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/{2, 2}, - /*time=*/time, + /*node_time=*/node_time, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc=*/false, @@ -155,7 +159,8 @@ TEST(TemporalNeighborTest, BasicAssertions) { /*col=*/col, /*seed=*/at::arange(2, 4, options), /*num_neighbors=*/{1, 2}, - /*time=*/time, + /*node_time=*/node_time, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/c10::nullopt, /*csc=*/false, @@ -170,6 +175,42 @@ TEST(TemporalNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<3>(out1).value(), std::get<3>(out2).value())); } +TEST(EdgeLevelTemporalNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + auto rowptr = std::get<0>(graph); + auto col = std::get<1>(graph); + + // Time is equal to edge ID: + auto edge_time = at::arange(col.numel(), options); + + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/rowptr, + /*col=*/col, + /*seed=*/at::arange(2, 4, options), + /*num_neighbors=*/{2, 2}, + /*node_time=*/c10::nullopt, + /*edge_time=*/edge_time, + /*seed_time=*/at::arange(5, 7, options), + /*edge_weight=*/c10::nullopt, + /*csc=*/false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true); + + // Expect only the earlier neighbors or the same node to be sampled: + auto expected_row = at::tensor({0, 0, 1, 2, 2, 4, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); + auto expected_col = at::tensor({2, 3, 4, 5, 0, 6, 1}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); + auto expected_nodes = + at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); + auto expected_edges = at::tensor({4, 5, 6, 2, 3, 4, 5}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); +} + TEST(HeteroNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); @@ -226,7 +267,8 @@ TEST(BiasedNeighborTest, BasicAssertions) { /*col=*/std::get<1>(graph), /*seed=*/at::arange(0, 2, options), /*num_neighbors=*/{1}, - /*time=*/c10::nullopt, + /*node_time=*/c10::nullopt, + /*edge_time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*edge_weight=*/edge_weight); @@ -274,7 +316,8 @@ TEST(HeteroBiasedNeighborTest, BasicAssertions) { /*col_dict=*/col_dict, /*seed_dict=*/seed_dict, /*num_neighbors_dict=*/num_neighbors_dict, - /*time_dict=*/c10::nullopt, + /*node_time_dict=*/c10::nullopt, + /*edge_time_dict=*/c10::nullopt, /*seed_time_dict=*/c10::nullopt, /*edge_weight_dict=*/edge_weight_dict);