From 2a50f4c410752169a80396f636f7bd268c312555 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Tue, 3 May 2022 16:19:21 +0000 Subject: [PATCH 01/13] init hetero subgraph --- pyg_lib/csrc/sampler/subgraph.cpp | 17 +++++++++++++++++ pyg_lib/csrc/sampler/subgraph.h | 10 ++++++++++ pyg_lib/csrc/utils/types.h | 14 ++++++++++++++ 3 files changed, 41 insertions(+) create mode 100644 pyg_lib/csrc/utils/types.h diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index dbedb5b6e..a08fc4274 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -25,10 +25,27 @@ std::tuple> subgraph( return op.call(rowptr, col, nodes, return_edge_id); } +c10::Dict>> +hetero_subgraph(const utils::HETERO_TENSOR_TYPE& rowptr, + const utils::HETERO_TENSOR_TYPE& col, + const utils::HETERO_TENSOR_TYPE& nodes, + const c10::Dict return_edge_id) { + c10::Dict>> + out_dict; + return out_dict; +} + TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::subgraph(Tensor rowptr, Tensor col, Tensor " "nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::hetero_subgraph(Dict(str, Tensor) rowptr, Dict(str, " + "Tensor) col, Dict(str, Tensor) nodes, Dict(str, bool) " + "return_edge_id) -> (Dict(str, Tensor), Dict(str, Tensor), " + "Dict(str, Tensor?))")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index a2f8de549..1c0190bfa 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -2,6 +2,7 @@ #include #include "pyg_lib/csrc/macros.h" +#include "pyg_lib/csrc/utils/types.h" namespace pyg { namespace sampler { @@ -15,5 +16,14 @@ PYG_API std::tuple> subgraph( const at::Tensor& nodes, const bool return_edge_id = true); +// A heterogeneous version of the above function. +// Returns a dict from each relation type to its result +PYG_API c10::Dict>> +hetero_subgraph(const utils::HETERO_TENSOR_TYPE& rowptr, + const utils::HETERO_TENSOR_TYPE& col, + const utils::HETERO_TENSOR_TYPE& nodes, + const c10::Dict return_edge_id); + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h new file mode 100644 index 000000000..ef3d39c2f --- /dev/null +++ b/pyg_lib/csrc/utils/types.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +#include + +namespace pyg { +namespace utils { +using RELATION_TYPE = std::string; + +using HETERO_TENSOR_TYPE = c10::Dict; +} // namespace utils + +} // namespace pyg From cba41ddd78c5677408cec227d64f3637530389d9 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Thu, 5 May 2022 09:00:20 +0000 Subject: [PATCH 02/13] update --- pyg_lib/csrc/sampler/subgraph.cpp | 12 ++++++------ pyg_lib/csrc/sampler/subgraph.h | 10 +++++----- pyg_lib/csrc/utils/types.h | 23 +++++++++++++++++++++-- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index a08fc4274..bb04bf724 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -25,13 +25,13 @@ std::tuple> subgraph( return op.call(rowptr, col, nodes, return_edge_id); } -c10::Dict>> -hetero_subgraph(const utils::HETERO_TENSOR_TYPE& rowptr, - const utils::HETERO_TENSOR_TYPE& col, - const utils::HETERO_TENSOR_TYPE& nodes, - const c10::Dict return_edge_id) { - c10::Dict return_edge_id) { + c10::Dict>> out_dict; return out_dict; diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index 1c0190bfa..53a9bb280 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -18,12 +18,12 @@ PYG_API std::tuple> subgraph( // A heterogeneous version of the above function. // Returns a dict from each relation type to its result -PYG_API c10::Dict>> -hetero_subgraph(const utils::HETERO_TENSOR_TYPE& rowptr, - const utils::HETERO_TENSOR_TYPE& col, - const utils::HETERO_TENSOR_TYPE& nodes, - const c10::Dict return_edge_id); +hetero_subgraph(const utils::edge_tensor_dict_t& rowptr, + const utils::edge_tensor_dict_t& col, + const utils::node_tensor_dict_t& nodes, + const c10::Dict return_edge_id); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h index ef3d39c2f..9d93deb1f 100644 --- a/pyg_lib/csrc/utils/types.h +++ b/pyg_lib/csrc/utils/types.h @@ -6,9 +6,28 @@ namespace pyg { namespace utils { -using RELATION_TYPE = std::string; -using HETERO_TENSOR_TYPE = c10::Dict; +const std::string SPLIT_TOKEN = "__"; + +using edge_t = std::string; +using node_t = std::string; +using rel_t = std::string; + +using edge_tensor_dict_t = c10::Dict; +using node_tensor_dict_t = c10::Dict; + +node_t get_src(const edge_t& e) { + return e.substr(0, e.find_first_of(SPLIT_TOKEN)); +} + +rel_t get_rel(const edge_t& e) { + auto beg = e.find_first_of(SPLIT_TOKEN) + SPLIT_TOKEN.size(); + return e.substr(beg, e.find_last_of(SPLIT_TOKEN) - beg); +} + +node_t get_dst(const edge_t& e) { + return e.substr(e.find_last_of(SPLIT_TOKEN) + SPLIT_TOKEN.size()); +} } // namespace utils } // namespace pyg From f87e8f314c743f8d30004e9b73a0e59f10f85849 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Thu, 5 May 2022 23:40:29 +0000 Subject: [PATCH 03/13] hetero dispatch logic --- pyg_lib/csrc/sampler/subgraph.cpp | 34 +++-- pyg_lib/csrc/sampler/subgraph.h | 2 +- pyg_lib/csrc/utils/hetero_dispatch.h | 183 +++++++++++++++++++++++++++ test/csrc/sampler/test_subgraph.cpp | 56 ++++++++ 4 files changed, 266 insertions(+), 9 deletions(-) create mode 100644 pyg_lib/csrc/utils/hetero_dispatch.h diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index bb04bf724..f3b41a3e5 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -1,8 +1,11 @@ #include "subgraph.h" +#include #include #include +#include + namespace pyg { namespace sampler { @@ -11,7 +14,7 @@ std::tuple> subgraph( const at::Tensor& col, const at::Tensor& nodes, const bool return_edge_id) { - at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; + at::TensorArg rowptr_t{rowptr, "rowptr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg nodes_t{nodes, "nodes", 1}; @@ -30,11 +33,27 @@ c10::Dict return_edge_id) { - c10::Dict>> - out_dict; - return out_dict; + const c10::Dict& return_edge_id) { + // Define the homogeneous implementation as a std function to pass the type + // check + std::function>( + const at::Tensor&, const at::Tensor&, const at::Tensor&, bool)> + func = subgraph; + + // Construct an operator + utils::HeteroDispatchOp op(rowptr, col, func); + + // Construct dispatchable arguments + // TODO: We filter source node by assuming hetero graph is a dict of homo + // graph here; both source and destination nodes should be considered when + // filtering a bipartite graph + utils::HeteroDispatchArg + nodes_arg(nodes); + utils::HeteroDispatchArg, bool, + utils::EdgeMode> + edge_id_arg(return_edge_id); + return op(nodes_arg, edge_id_arg); } TORCH_LIBRARY_FRAGMENT(pyg, m) { @@ -44,8 +63,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_subgraph(Dict(str, Tensor) rowptr, Dict(str, " "Tensor) col, Dict(str, Tensor) nodes, Dict(str, bool) " - "return_edge_id) -> (Dict(str, Tensor), Dict(str, Tensor), " - "Dict(str, Tensor?))")); + "return_edge_id) -> Dict(str, (Tensor, Tensor, Tensor?))")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index 53a9bb280..c303990ca 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -23,7 +23,7 @@ PYG_API c10::Dict return_edge_id); + const c10::Dict& return_edge_id); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/utils/hetero_dispatch.h b/pyg_lib/csrc/utils/hetero_dispatch.h new file mode 100644 index 000000000..122b53aae --- /dev/null +++ b/pyg_lib/csrc/utils/hetero_dispatch.h @@ -0,0 +1,183 @@ +#pragma once + +#include "types.h" + +#include + +namespace pyg { + +namespace utils { + +// Base class for easier type check +struct HeteroDispatchMode {}; + +// List hetero dispatch mode as different types to avoid non-type template +// specialization. +struct SkipMode : public HeteroDispatchMode {}; + +struct NodeSrcMode : public HeteroDispatchMode {}; + +struct NodeDstMode : public HeteroDispatchMode {}; + +struct EdgeMode : public HeteroDispatchMode {}; + +template +struct is_c10_dict : std::false_type {}; + +template +struct is_c10_dict> : std::true_type {}; + +// TODO: Should specialize as if-constexpr when in C++17 +template +class HeteroDispatchArg {}; + +// In SkipMode we do not filter this arg +template +class HeteroDispatchArg { + public: + HeteroDispatchArg(const T& val) : val_(val) {} + + // If we pass the filter, we will obtain the value of the argument. + template + V value_by_edge(const K& key) { + return val_; + } + + bool filter_by_edge(const edge_t& edge) { return true; } + + private: + T val_; +}; + +// In NodeSrcMode we check if source node is in the dict +template +class HeteroDispatchArg { + public: + HeteroDispatchArg(const T& val) : val_(val) { + static_assert(is_c10_dict::value, "Should be a c10::dict"); + } + + template + V value_by_edge(const K& key) { + return val_.at(get_src(key)); + } + + bool filter_by_edge(const edge_t& edge) { + return val_.contains(get_src(edge)); + } + + private: + T val_; +}; + +// In NodeDstMode we check if destination node is in the dict +template +class HeteroDispatchArg { + public: + HeteroDispatchArg(const T& val) : val_(val) { + static_assert(is_c10_dict::value, "Should be a c10::dict"); + } + + template + V value_by_edge(const K& key) { + return val_.at(get_dst(key)); + } + + bool filter_by_edge(const edge_t& edge) { + return val_.contains(get_dst(edge)); + } + + private: + T val_; +}; + +// In EdgeMode we check if edge is in the dict +template +class HeteroDispatchArg { + public: + HeteroDispatchArg(const T& val) : val_(val) { + static_assert(is_c10_dict::value, "Should be a c10::dict"); + } + + template + V value_by_edge(const K& key) { + return val_.at(key); + } + + bool filter_by_edge(const edge_t& edge) { return val_.contains(edge); } + + private: + T val_; +}; + +// The following will help static type checks: +template +struct is_hetero_arg : std::false_type {}; + +// Just check inheritance, a workaround without introducing concepts +template +struct is_hetero_arg> : std::true_type { + static_assert(std::is_base_of::value, + "Must pass a mode for dispatching"); +}; + +// Specialize +template +bool filter_args_by_edge(const edge_t& edge, Args&&... args) {} + +template <> +bool filter_args_by_edge(const edge_t& edge) { + return true; +} + +template +bool filter_args_by_edge(const edge_t& edge, T&& t, Args&&... args) { + static_assert( + is_hetero_arg>>::value, + "args should be HeteroDispatchArg"); + return t.filter_by_edge(edge) && filter_args_by_edge(edge, args...); +} + +template +struct is_std_function : std::false_type {}; + +template +struct is_std_function> : std::true_type {}; + +template +class HeteroDispatchOp { + public: + using result_type = typename T::result_type; + HeteroDispatchOp(const edge_tensor_dict_t& rowptr, + const edge_tensor_dict_t& col, + T op) + : rowptr_(rowptr), col_(col), op_(op) { + // Check early + static_assert(is_std_function::value, "Must pass a function"); + } + + template + c10::Dict operator()(Args&&... args) { + c10::Dict dict; + for (const auto& kv : rowptr_) { + auto edge = kv.key(); + auto rowptr = kv.value(); + auto col = col_.at(edge); + bool pass = filter_args_by_edge(edge, args...); + if (pass) { + result_type res = op_(rowptr, col, args.value_by_edge(edge)...); + dict.insert(edge, res); + } + } + return dict; + } + + private: + edge_tensor_dict_t rowptr_; + edge_tensor_dict_t col_; + T op_; +}; + +} // namespace utils + +} // namespace pyg diff --git a/test/csrc/sampler/test_subgraph.cpp b/test/csrc/sampler/test_subgraph.cpp index ea9235146..baa1732d2 100644 --- a/test/csrc/sampler/test_subgraph.cpp +++ b/test/csrc/sampler/test_subgraph.cpp @@ -20,3 +20,59 @@ TEST(SubgraphTest, BasicAssertions) { auto expected_edge_id = at::tensor({3, 4, 5, 6, 7, 8}, options); EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_edge_id)); } + +TEST(HeteroSubgraphPassFilterTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto nodes = at::arange(1, 5, options); + auto graph = cycle_graph(/*num_nodes=*/6, options); + + pyg::utils::node_t node_name = "node"; + pyg::utils::edge_t edge_name = "node__to__node"; + + pyg::utils::edge_tensor_dict_t rowptr_dict; + rowptr_dict.insert(edge_name, std::get<0>(graph)); + pyg::utils::edge_tensor_dict_t col_dict; + col_dict.insert(edge_name, std::get<1>(graph)); + pyg::utils::edge_tensor_dict_t nodes_dict; + nodes_dict.insert(node_name, nodes); + c10::Dict edge_id_dict; + edge_id_dict.insert(edge_name, true); + + auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict, + edge_id_dict); + + EXPECT_EQ(res.size(), 1); + auto out = res.at(edge_name); + + auto expected_rowptr = at::tensor({0, 1, 3, 5, 6}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_rowptr)); + auto expected_col = at::tensor({1, 0, 2, 1, 3, 2}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); + auto expected_edge_id = at::tensor({3, 4, 5, 6, 7, 8}, options); + EXPECT_TRUE(at::equal(std::get<2>(out).value(), expected_edge_id)); +} + +TEST(HeteroSubgraphFailFilterTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto nodes = at::arange(1, 5, options); + auto graph = cycle_graph(/*num_nodes=*/6, options); + + pyg::utils::node_t node_name = "node"; + pyg::utils::edge_t edge_name = "node123__to456__node321"; + + pyg::utils::edge_tensor_dict_t rowptr_dict; + rowptr_dict.insert(edge_name, std::get<0>(graph)); + pyg::utils::edge_tensor_dict_t col_dict; + col_dict.insert(edge_name, std::get<1>(graph)); + pyg::utils::edge_tensor_dict_t nodes_dict; + nodes_dict.insert(node_name, nodes); + c10::Dict edge_id_dict; + edge_id_dict.insert(edge_name, true); + + auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict, + edge_id_dict); + + EXPECT_EQ(res.size(), 0); +} From 8a1cb1426119e16166ca57a5d1a46da52438f64f Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Fri, 6 May 2022 12:03:00 +0000 Subject: [PATCH 04/13] update --- CHANGELOG.md | 1 + pyg_lib/csrc/utils/hetero_dispatch.h | 6 ++++++ pyg_lib/csrc/utils/types.h | 5 +++-- test/csrc/utils/test_utils.cpp | 15 +++++++++++++++ 4 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 test/csrc/utils/test_utils.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index eef5243ab..c11ed42ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] ### Added +- Added hetero subgraph kernel ([#43](https://github.com/pyg-team/pyg-lib/pull/43) - Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44) - Added `biased sampling` utils ([#38](https://github.com/pyg-team/pyg-lib/pull/38)) - Added `CHANGELOG.md` ([#39](https://github.com/pyg-team/pyg-lib/pull/39)) diff --git a/pyg_lib/csrc/utils/hetero_dispatch.h b/pyg_lib/csrc/utils/hetero_dispatch.h index 122b53aae..5f1e9a05c 100644 --- a/pyg_lib/csrc/utils/hetero_dispatch.h +++ b/pyg_lib/csrc/utils/hetero_dispatch.h @@ -21,6 +21,7 @@ struct NodeDstMode : public HeteroDispatchMode {}; struct EdgeMode : public HeteroDispatchMode {}; +// Check if the argument is a c10::dict so that is could be filtered by an edge type. template struct is_c10_dict : std::false_type {}; @@ -57,11 +58,13 @@ class HeteroDispatchArg { static_assert(is_c10_dict::value, "Should be a c10::dict"); } + // Dict value lookup template V value_by_edge(const K& key) { return val_.at(get_src(key)); } + // Dict if key exists bool filter_by_edge(const edge_t& edge) { return val_.contains(get_src(edge)); } @@ -125,11 +128,13 @@ struct is_hetero_arg> : std::true_type { template bool filter_args_by_edge(const edge_t& edge, Args&&... args) {} +// Stop condition of argument filtering template <> bool filter_args_by_edge(const edge_t& edge) { return true; } +// We filter each argument individually by the given edge using a variadic template template bool filter_args_by_edge(const edge_t& edge, T&& t, Args&&... args) { static_assert( @@ -138,6 +143,7 @@ bool filter_args_by_edge(const edge_t& edge, T&& t, Args&&... args) { return t.filter_by_edge(edge) && filter_args_by_edge(edge, args...); } +// Check if a callable is wrapped by std::function template struct is_std_function : std::false_type {}; diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h index 9d93deb1f..9b0de4925 100644 --- a/pyg_lib/csrc/utils/types.h +++ b/pyg_lib/csrc/utils/types.h @@ -22,11 +22,12 @@ node_t get_src(const edge_t& e) { rel_t get_rel(const edge_t& e) { auto beg = e.find_first_of(SPLIT_TOKEN) + SPLIT_TOKEN.size(); - return e.substr(beg, e.find_last_of(SPLIT_TOKEN) - beg); + return e.substr(beg, + e.find_last_of(SPLIT_TOKEN) - SPLIT_TOKEN.size() + 1 - beg); } node_t get_dst(const edge_t& e) { - return e.substr(e.find_last_of(SPLIT_TOKEN) + SPLIT_TOKEN.size()); + return e.substr(e.find_last_of(SPLIT_TOKEN) + 1); } } // namespace utils diff --git a/test/csrc/utils/test_utils.cpp b/test/csrc/utils/test_utils.cpp new file mode 100644 index 000000000..f1564767b --- /dev/null +++ b/test/csrc/utils/test_utils.cpp @@ -0,0 +1,15 @@ +#include + +#include "pyg_lib/csrc/sampler/subgraph.h" + +TEST(UtilsTypeTest, BasicAssertions) { + pyg::utils::edge_t edge = "node1__to__node2"; + + auto src = pyg::utils::get_src(edge); + auto dst = pyg::utils::get_dst(edge); + auto rel = pyg::utils::get_rel(edge); + + EXPECT_EQ(src, std::string("node1")); + EXPECT_EQ(dst, std::string("node2")); + EXPECT_EQ(rel, std::string("to")); +} From 08faf36da9294d5228724c3c06144606877f294b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 May 2022 12:03:11 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyg_lib/csrc/utils/hetero_dispatch.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyg_lib/csrc/utils/hetero_dispatch.h b/pyg_lib/csrc/utils/hetero_dispatch.h index 5f1e9a05c..d1a6fe2f0 100644 --- a/pyg_lib/csrc/utils/hetero_dispatch.h +++ b/pyg_lib/csrc/utils/hetero_dispatch.h @@ -21,7 +21,8 @@ struct NodeDstMode : public HeteroDispatchMode {}; struct EdgeMode : public HeteroDispatchMode {}; -// Check if the argument is a c10::dict so that is could be filtered by an edge type. +// Check if the argument is a c10::dict so that is could be filtered by an edge +// type. template struct is_c10_dict : std::false_type {}; @@ -134,7 +135,8 @@ bool filter_args_by_edge(const edge_t& edge) { return true; } -// We filter each argument individually by the given edge using a variadic template +// We filter each argument individually by the given edge using a variadic +// template template bool filter_args_by_edge(const edge_t& edge, T&& t, Args&&... args) { static_assert( From dd2e8b6be9a2faebb9f5a646a84b0359f71707b1 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Wed, 11 May 2022 23:20:54 +0000 Subject: [PATCH 06/13] refactor --- pyg_lib/csrc/sampler/cpu/mapper.h | 12 ++- pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp | 74 +--------------- pyg_lib/csrc/sampler/subgraph.cpp | 56 ++++++++++-- pyg_lib/csrc/sampler/subgraph.h | 93 +++++++++++++++++++- pyg_lib/csrc/utils/types.h | 6 +- test/csrc/sampler/test_subgraph.cpp | 4 +- 6 files changed, 154 insertions(+), 91 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/mapper.h b/pyg_lib/csrc/sampler/cpu/mapper.h index 47144ea6b..2074f8be3 100644 --- a/pyg_lib/csrc/sampler/cpu/mapper.h +++ b/pyg_lib/csrc/sampler/cpu/mapper.h @@ -10,6 +10,8 @@ namespace sampler { template class Mapper { public: + using type = scalar_t; + Mapper(scalar_t num_nodes, scalar_t num_entries) : num_nodes(num_nodes), num_entries(num_entries) { // Use a some simple heuristic to determine whether we can use a std::vector @@ -23,11 +25,13 @@ class Mapper { void fill(const scalar_t* nodes_data, const scalar_t size) { if (use_vec) { - for (scalar_t i = 0; i < size; ++i) + for (scalar_t i = 0; i < size; ++i) { to_local_vec[nodes_data[i]] = i; + } } else { - for (scalar_t i = 0; i < size; ++i) + for (scalar_t i = 0; i < size; ++i) { to_local_map.insert({nodes_data[i], i}); + } } } @@ -35,14 +39,14 @@ class Mapper { fill(nodes.data_ptr(), nodes.numel()); } - bool exists(const scalar_t& node) { + bool exists(const scalar_t& node) const { if (use_vec) return to_local_vec[node] >= 0; else return to_local_map.count(node) > 0; } - scalar_t map(const scalar_t& node) { + scalar_t map(const scalar_t& node) const { if (use_vec) return to_local_vec[node]; else { diff --git a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp index 9e67cddea..ccce9a464 100644 --- a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp @@ -3,6 +3,7 @@ #include #include "pyg_lib/csrc/sampler/cpu/mapper.h" +#include "pyg_lib/csrc/sampler/subgraph.h" #include "pyg_lib/csrc/utils/cpu/convert.h" namespace pyg { @@ -15,78 +16,7 @@ std::tuple> subgraph_kernel( const at::Tensor& col, const at::Tensor& nodes, const bool return_edge_id) { - TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); - TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); - TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor"); - - const auto num_nodes = rowptr.size(0) - 1; - const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); - at::Tensor out_col; - c10::optional out_edge_id = c10::nullopt; - - AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] { - auto mapper = pyg::sampler::Mapper(num_nodes, nodes.size(0)); - mapper.fill(nodes); - - const auto rowptr_data = rowptr.data_ptr(); - const auto col_data = col.data_ptr(); - const auto nodes_data = nodes.data_ptr(); - - // We first iterate over all nodes and collect information about the number - // of edges in the induced subgraph. - const auto deg = rowptr.new_empty({nodes.size(0)}); - auto deg_data = deg.data_ptr(); - auto grain_size = at::internal::GRAIN_SIZE; - at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { - for (size_t i = _s; i < _e; ++i) { - const auto v = nodes_data[i]; - // Iterate over all neighbors and check if they are part of `nodes`: - scalar_t d = 0; - for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { - if (mapper.exists(col_data[j])) - d++; - } - deg_data[i] = d; - } - }); - - auto out_rowptr_data = out_rowptr.data_ptr(); - out_rowptr_data[0] = 0; - auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); - at::cumsum_out(tmp, deg, /*dim=*/0); - - out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); - auto out_col_data = out_col.data_ptr(); - scalar_t* out_edge_id_data; - if (return_edge_id) { - out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); - out_edge_id_data = out_edge_id.value().data_ptr(); - } - - // Customize `grain_size` based on the work each thread does (it will need - // to find `col.size(0) / nodes.size(0)` neighbors on average). - // TODO Benchmark this customization - grain_size = std::max(out_col.size(0) / nodes.size(0), 1); - grain_size = at::internal::GRAIN_SIZE / grain_size; - at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { - for (scalar_t i = _s; i < _e; ++i) { - const auto v = nodes_data[i]; - // Iterate over all neighbors and check if they are part of `nodes`: - scalar_t offset = out_rowptr_data[i]; - for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { - const auto w = mapper.map(col_data[j]); - if (w >= 0) { - out_col_data[offset] = w; - if (return_edge_id) - out_edge_id_data[offset] = j; - offset++; - } - } - } - }); - }); - - return std::make_tuple(out_rowptr, out_col, out_edge_id); + return subgraph_bipartite(rowptr, col, nodes, nodes, return_edge_id); } } // namespace diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index f3b41a3e5..c0494c055 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -28,38 +28,76 @@ std::tuple> subgraph( return op.call(rowptr, col, nodes, return_edge_id); } +std::tuple> +subgraph_bipartite(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& src_nodes, + const at::Tensor& dst_nodes, + const bool return_edge_id) { + TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); + TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); + TORCH_CHECK(src_nodes.is_cpu(), "'src_nodes' must be a CPU tensor"); + TORCH_CHECK(dst_nodes.is_cpu(), "'dst_nodes' must be a CPU tensor"); + + const auto num_nodes = rowptr.size(0) - 1; + at::Tensor out_rowptr, out_col; + c10::optional out_edge_id; + + AT_DISPATCH_INTEGRAL_TYPES( + src_nodes.scalar_type(), "subgraph_bipartite", [&] { + // TODO: at::max parallel but still a little expensive + Mapper mapper(at::max(col).item() + 1, + dst_nodes.size(0)); + mapper.fill(dst_nodes); + + auto res = subgraph_with_mapper(rowptr, col, src_nodes, + mapper, return_edge_id); + out_rowptr = std::get<0>(res); + out_col = std::get<1>(res); + out_edge_id = std::get<2>(res); + }); + + return {out_rowptr, out_col, out_edge_id}; +} + c10::Dict>> hetero_subgraph(const utils::edge_tensor_dict_t& rowptr, const utils::edge_tensor_dict_t& col, - const utils::node_tensor_dict_t& nodes, + const utils::node_tensor_dict_t& src_nodes, + const utils::node_tensor_dict_t& dst_nodes, const c10::Dict& return_edge_id) { - // Define the homogeneous implementation as a std function to pass the type + // Define the bipartite implementation as a std function to pass the type // check std::function>( - const at::Tensor&, const at::Tensor&, const at::Tensor&, bool)> - func = subgraph; + const at::Tensor&, const at::Tensor&, const at::Tensor&, + const at::Tensor&, bool)> + func = subgraph_bipartite; // Construct an operator utils::HeteroDispatchOp op(rowptr, col, func); // Construct dispatchable arguments - // TODO: We filter source node by assuming hetero graph is a dict of homo - // graph here; both source and destination nodes should be considered when - // filtering a bipartite graph utils::HeteroDispatchArg - nodes_arg(nodes); + src_nodes_arg(src_nodes); + utils::HeteroDispatchArg + dst_nodes_arg(dst_nodes); utils::HeteroDispatchArg, bool, utils::EdgeMode> edge_id_arg(return_edge_id); - return op(nodes_arg, edge_id_arg); + return op(src_nodes_arg, dst_nodes_arg, edge_id_arg); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::subgraph(Tensor rowptr, Tensor col, Tensor " "nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::subgraph_bipartite(Tensor rowptr, Tensor col, Tensor " + "src_nodes, Tensor dst_nodes, bool return_edge_id) -> (Tensor, Tensor, " + "Tensor?)")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_subgraph(Dict(str, Tensor) rowptr, Dict(str, " "Tensor) col, Dict(str, Tensor) nodes, Dict(str, bool) " diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index c303990ca..94bee1825 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -1,7 +1,10 @@ #pragma once #include +#include + #include "pyg_lib/csrc/macros.h" +#include "pyg_lib/csrc/sampler/cpu/mapper.h" #include "pyg_lib/csrc/utils/types.h" namespace pyg { @@ -16,14 +19,102 @@ PYG_API std::tuple> subgraph( const at::Tensor& nodes, const bool return_edge_id = true); +// A bipartite version of the above function. +PYG_API std::tuple> +subgraph_bipartite(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& src_nodes, + const at::Tensor& dst_nodes, + const bool return_edge_id); + // A heterogeneous version of the above function. // Returns a dict from each relation type to its result PYG_API c10::Dict>> hetero_subgraph(const utils::edge_tensor_dict_t& rowptr, const utils::edge_tensor_dict_t& col, - const utils::node_tensor_dict_t& nodes, + const utils::node_tensor_dict_t& src_nodes, + const utils::node_tensor_dict_t& dst_nodes, const c10::Dict& return_edge_id); +template +std::tuple> +subgraph_with_mapper(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& nodes, + const Mapper& mapper, + const bool return_edge_id) { + const auto num_nodes = rowptr.size(0) - 1; + const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); + at::Tensor out_col; + c10::optional out_edge_id = c10::nullopt; + + AT_DISPATCH_INTEGRAL_TYPES( + nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] { + const auto rowptr_data = rowptr.data_ptr(); + const auto col_data = col.data_ptr(); + const auto nodes_data = nodes.data_ptr(); + + // We first iterate over all nodes and collect information about the + // number of edges in the induced subgraph. + const auto deg = rowptr.new_empty({nodes.size(0)}); + auto deg_data = deg.data_ptr(); + auto grain_size = at::internal::GRAIN_SIZE; + at::parallel_for( + 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { + for (size_t i = _s; i < _e; ++i) { + const auto v = nodes_data[i]; + // Iterate over all neighbors and check if they are part of + // `nodes`: + scalar_t d = 0; + for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { + if (mapper.exists(col_data[j])) + d++; + } + deg_data[i] = d; + } + }); + + auto out_rowptr_data = out_rowptr.data_ptr(); + out_rowptr_data[0] = 0; + auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); + at::cumsum_out(tmp, deg, /*dim=*/0); + + out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); + auto out_col_data = out_col.data_ptr(); + scalar_t* out_edge_id_data; + if (return_edge_id) { + out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); + out_edge_id_data = out_edge_id.value().data_ptr(); + } + + // Customize `grain_size` based on the work each thread does (it will + // need to find `col.size(0) / nodes.size(0)` neighbors on average). + // TODO Benchmark this customization + grain_size = std::max(out_col.size(0) / nodes.size(0), 1); + grain_size = at::internal::GRAIN_SIZE / grain_size; + at::parallel_for( + 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { + for (scalar_t i = _s; i < _e; ++i) { + const auto v = nodes_data[i]; + // Iterate over all neighbors and check if they + // are part of `nodes`: + scalar_t offset = out_rowptr_data[i]; + for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { + const auto w = mapper.map(col_data[j]); + if (w >= 0) { + out_col_data[offset] = w; + if (return_edge_id) + out_edge_id_data[offset] = j; + offset++; + } + } + } + }); + }); + + return std::make_tuple(out_rowptr, out_col, out_edge_id); +} + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h index 9b0de4925..f661d7ca8 100644 --- a/pyg_lib/csrc/utils/types.h +++ b/pyg_lib/csrc/utils/types.h @@ -16,17 +16,17 @@ using rel_t = std::string; using edge_tensor_dict_t = c10::Dict; using node_tensor_dict_t = c10::Dict; -node_t get_src(const edge_t& e) { +inline node_t get_src(const edge_t& e) { return e.substr(0, e.find_first_of(SPLIT_TOKEN)); } -rel_t get_rel(const edge_t& e) { +inline rel_t get_rel(const edge_t& e) { auto beg = e.find_first_of(SPLIT_TOKEN) + SPLIT_TOKEN.size(); return e.substr(beg, e.find_last_of(SPLIT_TOKEN) - SPLIT_TOKEN.size() + 1 - beg); } -node_t get_dst(const edge_t& e) { +inline node_t get_dst(const edge_t& e) { return e.substr(e.find_last_of(SPLIT_TOKEN) + 1); } } // namespace utils diff --git a/test/csrc/sampler/test_subgraph.cpp b/test/csrc/sampler/test_subgraph.cpp index baa1732d2..6e46638b2 100644 --- a/test/csrc/sampler/test_subgraph.cpp +++ b/test/csrc/sampler/test_subgraph.cpp @@ -40,7 +40,7 @@ TEST(HeteroSubgraphPassFilterTest, BasicAssertions) { edge_id_dict.insert(edge_name, true); auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict, - edge_id_dict); + nodes_dict, edge_id_dict); EXPECT_EQ(res.size(), 1); auto out = res.at(edge_name); @@ -72,7 +72,7 @@ TEST(HeteroSubgraphFailFilterTest, BasicAssertions) { edge_id_dict.insert(edge_name, true); auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict, - edge_id_dict); + nodes_dict, edge_id_dict); EXPECT_EQ(res.size(), 0); } From d1c98cc07b0660d0362f8f2350b46662585d8e29 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Thu, 12 May 2022 12:24:41 +0000 Subject: [PATCH 07/13] better structure --- pyg_lib/csrc/sampler/subgraph.cpp | 79 +++++++++++++++++++++++++++++++ pyg_lib/csrc/sampler/subgraph.h | 73 +--------------------------- 2 files changed, 80 insertions(+), 72 deletions(-) diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index c0494c055..7c1cfe668 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -9,6 +9,85 @@ namespace pyg { namespace sampler { +template +std::tuple> +subgraph_with_mapper(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& nodes, + const Mapper& mapper, + const bool return_edge_id) { + const auto num_nodes = rowptr.size(0) - 1; + const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); + at::Tensor out_col; + c10::optional out_edge_id = c10::nullopt; + + AT_DISPATCH_INTEGRAL_TYPES( + nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] { + const auto rowptr_data = rowptr.data_ptr(); + const auto col_data = col.data_ptr(); + const auto nodes_data = nodes.data_ptr(); + + // We first iterate over all nodes and collect information about the + // number of edges in the induced subgraph. + const auto deg = rowptr.new_empty({nodes.size(0)}); + auto deg_data = deg.data_ptr(); + auto grain_size = at::internal::GRAIN_SIZE; + at::parallel_for( + 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { + for (size_t i = _s; i < _e; ++i) { + const auto v = nodes_data[i]; + // Iterate over all neighbors and check if they are part of + // `nodes`: + scalar_t d = 0; + for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { + if (mapper.exists(col_data[j])) + d++; + } + deg_data[i] = d; + } + }); + + auto out_rowptr_data = out_rowptr.data_ptr(); + out_rowptr_data[0] = 0; + auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); + at::cumsum_out(tmp, deg, /*dim=*/0); + + out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); + auto out_col_data = out_col.data_ptr(); + scalar_t* out_edge_id_data; + if (return_edge_id) { + out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); + out_edge_id_data = out_edge_id.value().data_ptr(); + } + + // Customize `grain_size` based on the work each thread does (it will + // need to find `col.size(0) / nodes.size(0)` neighbors on average). + // TODO Benchmark this customization + grain_size = std::max(out_col.size(0) / nodes.size(0), 1); + grain_size = at::internal::GRAIN_SIZE / grain_size; + at::parallel_for( + 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { + for (scalar_t i = _s; i < _e; ++i) { + const auto v = nodes_data[i]; + // Iterate over all neighbors and check if they + // are part of `nodes`: + scalar_t offset = out_rowptr_data[i]; + for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { + const auto w = mapper.map(col_data[j]); + if (w >= 0) { + out_col_data[offset] = w; + if (return_edge_id) + out_edge_id_data[offset] = j; + offset++; + } + } + } + }); + }); + + return std::make_tuple(out_rowptr, out_col, out_edge_id); +} + std::tuple> subgraph( const at::Tensor& rowptr, const at::Tensor& col, diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index 94bee1825..0b5b93a5b 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -43,78 +43,7 @@ subgraph_with_mapper(const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& nodes, const Mapper& mapper, - const bool return_edge_id) { - const auto num_nodes = rowptr.size(0) - 1; - const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); - at::Tensor out_col; - c10::optional out_edge_id = c10::nullopt; - - AT_DISPATCH_INTEGRAL_TYPES( - nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] { - const auto rowptr_data = rowptr.data_ptr(); - const auto col_data = col.data_ptr(); - const auto nodes_data = nodes.data_ptr(); - - // We first iterate over all nodes and collect information about the - // number of edges in the induced subgraph. - const auto deg = rowptr.new_empty({nodes.size(0)}); - auto deg_data = deg.data_ptr(); - auto grain_size = at::internal::GRAIN_SIZE; - at::parallel_for( - 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { - for (size_t i = _s; i < _e; ++i) { - const auto v = nodes_data[i]; - // Iterate over all neighbors and check if they are part of - // `nodes`: - scalar_t d = 0; - for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { - if (mapper.exists(col_data[j])) - d++; - } - deg_data[i] = d; - } - }); - - auto out_rowptr_data = out_rowptr.data_ptr(); - out_rowptr_data[0] = 0; - auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); - at::cumsum_out(tmp, deg, /*dim=*/0); - - out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); - auto out_col_data = out_col.data_ptr(); - scalar_t* out_edge_id_data; - if (return_edge_id) { - out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); - out_edge_id_data = out_edge_id.value().data_ptr(); - } - - // Customize `grain_size` based on the work each thread does (it will - // need to find `col.size(0) / nodes.size(0)` neighbors on average). - // TODO Benchmark this customization - grain_size = std::max(out_col.size(0) / nodes.size(0), 1); - grain_size = at::internal::GRAIN_SIZE / grain_size; - at::parallel_for( - 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { - for (scalar_t i = _s; i < _e; ++i) { - const auto v = nodes_data[i]; - // Iterate over all neighbors and check if they - // are part of `nodes`: - scalar_t offset = out_rowptr_data[i]; - for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { - const auto w = mapper.map(col_data[j]); - if (w >= 0) { - out_col_data[offset] = w; - if (return_edge_id) - out_edge_id_data[offset] = j; - offset++; - } - } - } - }); - }); - - return std::make_tuple(out_rowptr, out_col, out_edge_id); -} + const bool return_edge_id); } // namespace sampler } // namespace pyg From 0a4bc0169bef3ea652ec4048fba70f77638ec7e8 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Thu, 12 May 2022 22:30:51 +0000 Subject: [PATCH 08/13] fix type name --- pyg_lib/csrc/sampler/cpu/mapper.h | 2 -- pyg_lib/csrc/sampler/subgraph.cpp | 18 ++++++++--------- pyg_lib/csrc/sampler/subgraph.h | 12 +++++------ pyg_lib/csrc/utils/hetero_dispatch.h | 30 ++++++++++++++-------------- pyg_lib/csrc/utils/types.h | 14 ++++++------- test/csrc/sampler/test_subgraph.cpp | 24 +++++++++++----------- test/csrc/utils/test_utils.cpp | 2 +- 7 files changed, 50 insertions(+), 52 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/mapper.h b/pyg_lib/csrc/sampler/cpu/mapper.h index 2074f8be3..a54fe0f8f 100644 --- a/pyg_lib/csrc/sampler/cpu/mapper.h +++ b/pyg_lib/csrc/sampler/cpu/mapper.h @@ -10,8 +10,6 @@ namespace sampler { template class Mapper { public: - using type = scalar_t; - Mapper(scalar_t num_nodes, scalar_t num_entries) : num_nodes(num_nodes), num_entries(num_entries) { // Use a some simple heuristic to determine whether we can use a std::vector diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index 7c1cfe668..13064df53 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -139,13 +139,13 @@ subgraph_bipartite(const at::Tensor& rowptr, return {out_rowptr, out_col, out_edge_id}; } -c10::Dict>> -hetero_subgraph(const utils::edge_tensor_dict_t& rowptr, - const utils::edge_tensor_dict_t& col, - const utils::node_tensor_dict_t& src_nodes, - const utils::node_tensor_dict_t& dst_nodes, - const c10::Dict& return_edge_id) { +hetero_subgraph(const utils::EdgeTensorDict& rowptr, + const utils::EdgeTensorDict& col, + const utils::NodeTensorDict& src_nodes, + const utils::NodeTensorDict& dst_nodes, + const c10::Dict& return_edge_id) { // Define the bipartite implementation as a std function to pass the type // check std::function>( @@ -157,13 +157,13 @@ hetero_subgraph(const utils::edge_tensor_dict_t& rowptr, utils::HeteroDispatchOp op(rowptr, col, func); // Construct dispatchable arguments - utils::HeteroDispatchArg src_nodes_arg(src_nodes); - utils::HeteroDispatchArg dst_nodes_arg(dst_nodes); - utils::HeteroDispatchArg, bool, + utils::HeteroDispatchArg, bool, utils::EdgeMode> edge_id_arg(return_edge_id); return op(src_nodes_arg, dst_nodes_arg, edge_id_arg); diff --git a/pyg_lib/csrc/sampler/subgraph.h b/pyg_lib/csrc/sampler/subgraph.h index 0b5b93a5b..a990b2121 100644 --- a/pyg_lib/csrc/sampler/subgraph.h +++ b/pyg_lib/csrc/sampler/subgraph.h @@ -29,13 +29,13 @@ subgraph_bipartite(const at::Tensor& rowptr, // A heterogeneous version of the above function. // Returns a dict from each relation type to its result -PYG_API c10::Dict>> -hetero_subgraph(const utils::edge_tensor_dict_t& rowptr, - const utils::edge_tensor_dict_t& col, - const utils::node_tensor_dict_t& src_nodes, - const utils::node_tensor_dict_t& dst_nodes, - const c10::Dict& return_edge_id); +hetero_subgraph(const utils::EdgeTensorDict& rowptr, + const utils::EdgeTensorDict& col, + const utils::NodeTensorDict& src_nodes, + const utils::NodeTensorDict& dst_nodes, + const c10::Dict& return_edge_id); template std::tuple> diff --git a/pyg_lib/csrc/utils/hetero_dispatch.h b/pyg_lib/csrc/utils/hetero_dispatch.h index d1a6fe2f0..f7aa3d046 100644 --- a/pyg_lib/csrc/utils/hetero_dispatch.h +++ b/pyg_lib/csrc/utils/hetero_dispatch.h @@ -45,7 +45,7 @@ class HeteroDispatchArg { return val_; } - bool filter_by_edge(const edge_t& edge) { return true; } + bool filter_by_edge(const EdgeType& edge) { return true; } private: T val_; @@ -66,7 +66,7 @@ class HeteroDispatchArg { } // Dict if key exists - bool filter_by_edge(const edge_t& edge) { + bool filter_by_edge(const EdgeType& edge) { return val_.contains(get_src(edge)); } @@ -87,7 +87,7 @@ class HeteroDispatchArg { return val_.at(get_dst(key)); } - bool filter_by_edge(const edge_t& edge) { + bool filter_by_edge(const EdgeType& edge) { return val_.contains(get_dst(edge)); } @@ -108,7 +108,7 @@ class HeteroDispatchArg { return val_.at(key); } - bool filter_by_edge(const edge_t& edge) { return val_.contains(edge); } + bool filter_by_edge(const EdgeType& edge) { return val_.contains(edge); } private: T val_; @@ -127,18 +127,18 @@ struct is_hetero_arg> : std::true_type { // Specialize template -bool filter_args_by_edge(const edge_t& edge, Args&&... args) {} +bool filter_args_by_edge(const EdgeType& edge, Args&&... args) {} // Stop condition of argument filtering template <> -bool filter_args_by_edge(const edge_t& edge) { +bool filter_args_by_edge(const EdgeType& edge) { return true; } // We filter each argument individually by the given edge using a variadic // template template -bool filter_args_by_edge(const edge_t& edge, T&& t, Args&&... args) { +bool filter_args_by_edge(const EdgeType& edge, T&& t, Args&&... args) { static_assert( is_hetero_arg>>::value, "args should be HeteroDispatchArg"); @@ -155,9 +155,9 @@ struct is_std_function> : std::true_type {}; template class HeteroDispatchOp { public: - using result_type = typename T::result_type; - HeteroDispatchOp(const edge_tensor_dict_t& rowptr, - const edge_tensor_dict_t& col, + using ResultType = typename T::result_type; + HeteroDispatchOp(const EdgeTensorDict& rowptr, + const EdgeTensorDict& col, T op) : rowptr_(rowptr), col_(col), op_(op) { // Check early @@ -165,15 +165,15 @@ class HeteroDispatchOp { } template - c10::Dict operator()(Args&&... args) { - c10::Dict dict; + c10::Dict operator()(Args&&... args) { + c10::Dict dict; for (const auto& kv : rowptr_) { auto edge = kv.key(); auto rowptr = kv.value(); auto col = col_.at(edge); bool pass = filter_args_by_edge(edge, args...); if (pass) { - result_type res = op_(rowptr, col, args.value_by_edge(edge)...); + ResultType res = op_(rowptr, col, args.value_by_edge(edge)...); dict.insert(edge, res); } } @@ -181,8 +181,8 @@ class HeteroDispatchOp { } private: - edge_tensor_dict_t rowptr_; - edge_tensor_dict_t col_; + EdgeTensorDict rowptr_; + EdgeTensorDict col_; T op_; }; diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h index f661d7ca8..0a1ae3915 100644 --- a/pyg_lib/csrc/utils/types.h +++ b/pyg_lib/csrc/utils/types.h @@ -9,24 +9,24 @@ namespace utils { const std::string SPLIT_TOKEN = "__"; -using edge_t = std::string; -using node_t = std::string; +using EdgeType = std::string; +using NodeType = std::string; using rel_t = std::string; -using edge_tensor_dict_t = c10::Dict; -using node_tensor_dict_t = c10::Dict; +using EdgeTensorDict = c10::Dict; +using NodeTensorDict = c10::Dict; -inline node_t get_src(const edge_t& e) { +inline NodeType get_src(const EdgeType& e) { return e.substr(0, e.find_first_of(SPLIT_TOKEN)); } -inline rel_t get_rel(const edge_t& e) { +inline rel_t get_rel(const EdgeType& e) { auto beg = e.find_first_of(SPLIT_TOKEN) + SPLIT_TOKEN.size(); return e.substr(beg, e.find_last_of(SPLIT_TOKEN) - SPLIT_TOKEN.size() + 1 - beg); } -inline node_t get_dst(const edge_t& e) { +inline NodeType get_dst(const EdgeType& e) { return e.substr(e.find_last_of(SPLIT_TOKEN) + 1); } } // namespace utils diff --git a/test/csrc/sampler/test_subgraph.cpp b/test/csrc/sampler/test_subgraph.cpp index 6e46638b2..25c143c28 100644 --- a/test/csrc/sampler/test_subgraph.cpp +++ b/test/csrc/sampler/test_subgraph.cpp @@ -27,16 +27,16 @@ TEST(HeteroSubgraphPassFilterTest, BasicAssertions) { auto nodes = at::arange(1, 5, options); auto graph = cycle_graph(/*num_nodes=*/6, options); - pyg::utils::node_t node_name = "node"; - pyg::utils::edge_t edge_name = "node__to__node"; + pyg::utils::NodeType node_name = "node"; + pyg::utils::EdgeType edge_name = "node__to__node"; - pyg::utils::edge_tensor_dict_t rowptr_dict; + pyg::utils::EdgeTensorDict rowptr_dict; rowptr_dict.insert(edge_name, std::get<0>(graph)); - pyg::utils::edge_tensor_dict_t col_dict; + pyg::utils::EdgeTensorDict col_dict; col_dict.insert(edge_name, std::get<1>(graph)); - pyg::utils::edge_tensor_dict_t nodes_dict; + pyg::utils::EdgeTensorDict nodes_dict; nodes_dict.insert(node_name, nodes); - c10::Dict edge_id_dict; + c10::Dict edge_id_dict; edge_id_dict.insert(edge_name, true); auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict, @@ -59,16 +59,16 @@ TEST(HeteroSubgraphFailFilterTest, BasicAssertions) { auto nodes = at::arange(1, 5, options); auto graph = cycle_graph(/*num_nodes=*/6, options); - pyg::utils::node_t node_name = "node"; - pyg::utils::edge_t edge_name = "node123__to456__node321"; + pyg::utils::NodeType node_name = "node"; + pyg::utils::EdgeType edge_name = "node123__to456__node321"; - pyg::utils::edge_tensor_dict_t rowptr_dict; + pyg::utils::EdgeTensorDict rowptr_dict; rowptr_dict.insert(edge_name, std::get<0>(graph)); - pyg::utils::edge_tensor_dict_t col_dict; + pyg::utils::EdgeTensorDict col_dict; col_dict.insert(edge_name, std::get<1>(graph)); - pyg::utils::edge_tensor_dict_t nodes_dict; + pyg::utils::EdgeTensorDict nodes_dict; nodes_dict.insert(node_name, nodes); - c10::Dict edge_id_dict; + c10::Dict edge_id_dict; edge_id_dict.insert(edge_name, true); auto res = pyg::sampler::hetero_subgraph(rowptr_dict, col_dict, nodes_dict, diff --git a/test/csrc/utils/test_utils.cpp b/test/csrc/utils/test_utils.cpp index f1564767b..06ff2a803 100644 --- a/test/csrc/utils/test_utils.cpp +++ b/test/csrc/utils/test_utils.cpp @@ -3,7 +3,7 @@ #include "pyg_lib/csrc/sampler/subgraph.h" TEST(UtilsTypeTest, BasicAssertions) { - pyg::utils::edge_t edge = "node1__to__node2"; + pyg::utils::EdgeType edge = "node1__to__node2"; auto src = pyg::utils::get_src(edge); auto dst = pyg::utils::get_dst(edge); From 1f20313b7d11b3d3641d29f659aabe3d62e20516 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Fri, 13 May 2022 11:08:09 +0000 Subject: [PATCH 09/13] structure --- pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp | 115 ++++++++++++++++++- pyg_lib/csrc/sampler/subgraph.cpp | 112 ++---------------- pyg_lib/csrc/sampler/subgraph.h | 8 -- 3 files changed, 125 insertions(+), 110 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp index ccce9a464..98330cca2 100644 --- a/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp @@ -11,18 +11,131 @@ namespace sampler { namespace { +template +std::tuple> +subgraph_with_mapper(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& nodes, + const Mapper& mapper, + const bool return_edge_id) { + const auto num_nodes = rowptr.size(0) - 1; + const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); + at::Tensor out_col; + c10::optional out_edge_id = c10::nullopt; + + AT_DISPATCH_INTEGRAL_TYPES( + nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] { + const auto rowptr_data = rowptr.data_ptr(); + const auto col_data = col.data_ptr(); + const auto nodes_data = nodes.data_ptr(); + + // We first iterate over all nodes and collect information about the + // number of edges in the induced subgraph. + const auto deg = rowptr.new_empty({nodes.size(0)}); + auto deg_data = deg.data_ptr(); + auto grain_size = at::internal::GRAIN_SIZE; + at::parallel_for( + 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { + for (size_t i = _s; i < _e; ++i) { + const auto v = nodes_data[i]; + // Iterate over all neighbors and check if they are part of + // `nodes`: + scalar_t d = 0; + for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { + if (mapper.exists(col_data[j])) + d++; + } + deg_data[i] = d; + } + }); + + auto out_rowptr_data = out_rowptr.data_ptr(); + out_rowptr_data[0] = 0; + auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); + at::cumsum_out(tmp, deg, /*dim=*/0); + + out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); + auto out_col_data = out_col.data_ptr(); + scalar_t* out_edge_id_data; + if (return_edge_id) { + out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); + out_edge_id_data = out_edge_id.value().data_ptr(); + } + + // Customize `grain_size` based on the work each thread does (it will + // need to find `col.size(0) / nodes.size(0)` neighbors on average). + // TODO Benchmark this customization + grain_size = std::max(out_col.size(0) / nodes.size(0), 1); + grain_size = at::internal::GRAIN_SIZE / grain_size; + at::parallel_for( + 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { + for (scalar_t i = _s; i < _e; ++i) { + const auto v = nodes_data[i]; + // Iterate over all neighbors and check if they + // are part of `nodes`: + scalar_t offset = out_rowptr_data[i]; + for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { + const auto w = mapper.map(col_data[j]); + if (w >= 0) { + out_col_data[offset] = w; + if (return_edge_id) + out_edge_id_data[offset] = j; + offset++; + } + } + } + }); + }); + + return std::make_tuple(out_rowptr, out_col, out_edge_id); +} + +std::tuple> +subgraph_bipartite_kernel(const at::Tensor& rowptr, + const at::Tensor& col, + const at::Tensor& src_nodes, + const at::Tensor& dst_nodes, + const bool return_edge_id) { + TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); + TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); + TORCH_CHECK(src_nodes.is_cpu(), "'src_nodes' must be a CPU tensor"); + TORCH_CHECK(dst_nodes.is_cpu(), "'dst_nodes' must be a CPU tensor"); + + const auto num_nodes = rowptr.size(0) - 1; + at::Tensor out_rowptr, out_col; + c10::optional out_edge_id; + + AT_DISPATCH_INTEGRAL_TYPES( + src_nodes.scalar_type(), "subgraph_bipartite_kernel", [&] { + // TODO: at::max parallel but still a little expensive + Mapper mapper(at::max(col).item() + 1, + dst_nodes.size(0)); + mapper.fill(dst_nodes); + + auto res = subgraph_with_mapper(rowptr, col, src_nodes, + mapper, return_edge_id); + out_rowptr = std::get<0>(res); + out_col = std::get<1>(res); + out_edge_id = std::get<2>(res); + }); + + return {out_rowptr, out_col, out_edge_id}; +} + std::tuple> subgraph_kernel( const at::Tensor& rowptr, const at::Tensor& col, const at::Tensor& nodes, const bool return_edge_id) { - return subgraph_bipartite(rowptr, col, nodes, nodes, return_edge_id); + return subgraph_bipartite_kernel(rowptr, col, nodes, nodes, return_edge_id); } } // namespace TORCH_LIBRARY_IMPL(pyg, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph"), TORCH_FN(subgraph_kernel)); + m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph_bipartite"), + TORCH_FN(subgraph_bipartite_kernel)); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index 13064df53..211543f62 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -9,85 +9,6 @@ namespace pyg { namespace sampler { -template -std::tuple> -subgraph_with_mapper(const at::Tensor& rowptr, - const at::Tensor& col, - const at::Tensor& nodes, - const Mapper& mapper, - const bool return_edge_id) { - const auto num_nodes = rowptr.size(0) - 1; - const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); - at::Tensor out_col; - c10::optional out_edge_id = c10::nullopt; - - AT_DISPATCH_INTEGRAL_TYPES( - nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] { - const auto rowptr_data = rowptr.data_ptr(); - const auto col_data = col.data_ptr(); - const auto nodes_data = nodes.data_ptr(); - - // We first iterate over all nodes and collect information about the - // number of edges in the induced subgraph. - const auto deg = rowptr.new_empty({nodes.size(0)}); - auto deg_data = deg.data_ptr(); - auto grain_size = at::internal::GRAIN_SIZE; - at::parallel_for( - 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { - for (size_t i = _s; i < _e; ++i) { - const auto v = nodes_data[i]; - // Iterate over all neighbors and check if they are part of - // `nodes`: - scalar_t d = 0; - for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { - if (mapper.exists(col_data[j])) - d++; - } - deg_data[i] = d; - } - }); - - auto out_rowptr_data = out_rowptr.data_ptr(); - out_rowptr_data[0] = 0; - auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); - at::cumsum_out(tmp, deg, /*dim=*/0); - - out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); - auto out_col_data = out_col.data_ptr(); - scalar_t* out_edge_id_data; - if (return_edge_id) { - out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); - out_edge_id_data = out_edge_id.value().data_ptr(); - } - - // Customize `grain_size` based on the work each thread does (it will - // need to find `col.size(0) / nodes.size(0)` neighbors on average). - // TODO Benchmark this customization - grain_size = std::max(out_col.size(0) / nodes.size(0), 1); - grain_size = at::internal::GRAIN_SIZE / grain_size; - at::parallel_for( - 0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { - for (scalar_t i = _s; i < _e; ++i) { - const auto v = nodes_data[i]; - // Iterate over all neighbors and check if they - // are part of `nodes`: - scalar_t offset = out_rowptr_data[i]; - for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { - const auto w = mapper.map(col_data[j]); - if (w >= 0) { - out_col_data[offset] = w; - if (return_edge_id) - out_edge_id_data[offset] = j; - offset++; - } - } - } - }); - }); - - return std::make_tuple(out_rowptr, out_col, out_edge_id); -} - std::tuple> subgraph( const at::Tensor& rowptr, const at::Tensor& col, @@ -113,30 +34,19 @@ subgraph_bipartite(const at::Tensor& rowptr, const at::Tensor& src_nodes, const at::Tensor& dst_nodes, const bool return_edge_id) { - TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); - TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); - TORCH_CHECK(src_nodes.is_cpu(), "'src_nodes' must be a CPU tensor"); - TORCH_CHECK(dst_nodes.is_cpu(), "'dst_nodes' must be a CPU tensor"); - - const auto num_nodes = rowptr.size(0) - 1; - at::Tensor out_rowptr, out_col; - c10::optional out_edge_id; - - AT_DISPATCH_INTEGRAL_TYPES( - src_nodes.scalar_type(), "subgraph_bipartite", [&] { - // TODO: at::max parallel but still a little expensive - Mapper mapper(at::max(col).item() + 1, - dst_nodes.size(0)); - mapper.fill(dst_nodes); + at::TensorArg rowptr_t{rowptr, "rowptr", 1}; + at::TensorArg col_t{col, "col", 1}; + at::TensorArg src_nodes_t{src_nodes, "src_nodes", 1}; + at::TensorArg dst_nodes_t{dst_nodes, "dst_nodes", 1}; - auto res = subgraph_with_mapper(rowptr, col, src_nodes, - mapper, return_edge_id); - out_rowptr = std::get<0>(res); - out_col = std::get<1>(res); - out_edge_id = std::get<2>(res); - }); + at::CheckedFrom c = "subgraph_bipartite"; + at::checkAllDefined(c, {rowptr_t, col_t, src_nodes_t, dst_nodes_t}); + at::checkAllSameType(c, {rowptr_t, col_t, src_nodes_t, dst_nodes_t}); - return {out_rowptr, out_col, out_edge_id}; + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::subgraph_bipartite", "") + .typed(); + return op.call(rowptr, col, src_nodes, dst_nodes, return_edge_id); } c10::Dict& return_edge_id); -template -std::tuple> -subgraph_with_mapper(const at::Tensor& rowptr, - const at::Tensor& col, - const at::Tensor& nodes, - const Mapper& mapper, - const bool return_edge_id); - } // namespace sampler } // namespace pyg From f8f905964cf5941af100738518514f31626d1b10 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Fri, 13 May 2022 13:28:10 +0000 Subject: [PATCH 10/13] simplify code using etype loop --- pyg_lib/csrc/sampler/subgraph.cpp | 35 +++++++----- pyg_lib/csrc/utils/hetero_dispatch.h | 81 ++++++++++------------------ pyg_lib/csrc/utils/types.h | 4 +- 3 files changed, 50 insertions(+), 70 deletions(-) diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index 211543f62..7ee90a464 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -56,15 +56,9 @@ hetero_subgraph(const utils::EdgeTensorDict& rowptr, const utils::NodeTensorDict& src_nodes, const utils::NodeTensorDict& dst_nodes, const c10::Dict& return_edge_id) { - // Define the bipartite implementation as a std function to pass the type - // check - std::function>( - const at::Tensor&, const at::Tensor&, const at::Tensor&, - const at::Tensor&, bool)> - func = subgraph_bipartite; - - // Construct an operator - utils::HeteroDispatchOp op(rowptr, col, func); + c10::Dict>> + res; // Construct dispatchable arguments utils::HeteroDispatchArg, bool, utils::EdgeMode> edge_id_arg(return_edge_id); - return op(src_nodes_arg, dst_nodes_arg, edge_id_arg); + + for (const auto& kv : rowptr) { + const auto& edge_type = kv.key(); + bool pass = filter_args_by_edge(edge_type, src_nodes_arg, dst_nodes_arg, + edge_id_arg); + if (pass) { + auto vals = value_args_by_edge(edge_type, src_nodes_arg, dst_nodes_arg, + edge_id_arg); + const auto& r = rowptr.at(edge_type); + const auto& c = col.at(edge_type); + res.insert(edge_type, + subgraph_bipartite(r, c, std::get<0>(vals), std::get<1>(vals), + std::get<2>(vals))); + } + } + + return res; } TORCH_LIBRARY_FRAGMENT(pyg, m) { @@ -87,10 +97,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "pyg::subgraph_bipartite(Tensor rowptr, Tensor col, Tensor " "src_nodes, Tensor dst_nodes, bool return_edge_id) -> (Tensor, Tensor, " "Tensor?)")); - m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::hetero_subgraph(Dict(str, Tensor) rowptr, Dict(str, " - "Tensor) col, Dict(str, Tensor) nodes, Dict(str, bool) " - "return_edge_id) -> Dict(str, (Tensor, Tensor, Tensor?))")); + m.def("hetero_subgraph", hetero_subgraph); } } // namespace sampler diff --git a/pyg_lib/csrc/utils/hetero_dispatch.h b/pyg_lib/csrc/utils/hetero_dispatch.h index f7aa3d046..2569bae71 100644 --- a/pyg_lib/csrc/utils/hetero_dispatch.h +++ b/pyg_lib/csrc/utils/hetero_dispatch.h @@ -37,13 +37,11 @@ class HeteroDispatchArg {}; template class HeteroDispatchArg { public: + using ValueType = V; HeteroDispatchArg(const T& val) : val_(val) {} // If we pass the filter, we will obtain the value of the argument. - template - V value_by_edge(const K& key) { - return val_; - } + V value_by_edge(const EdgeType& edge) { return val_; } bool filter_by_edge(const EdgeType& edge) { return true; } @@ -55,15 +53,13 @@ class HeteroDispatchArg { template class HeteroDispatchArg { public: + using ValueType = V; HeteroDispatchArg(const T& val) : val_(val) { static_assert(is_c10_dict::value, "Should be a c10::dict"); } // Dict value lookup - template - V value_by_edge(const K& key) { - return val_.at(get_src(key)); - } + V value_by_edge(const EdgeType& edge) { return val_.at(get_src(edge)); } // Dict if key exists bool filter_by_edge(const EdgeType& edge) { @@ -78,14 +74,12 @@ class HeteroDispatchArg { template class HeteroDispatchArg { public: + using ValueType = V; HeteroDispatchArg(const T& val) : val_(val) { static_assert(is_c10_dict::value, "Should be a c10::dict"); } - template - V value_by_edge(const K& key) { - return val_.at(get_dst(key)); - } + V value_by_edge(const EdgeType& edge) { return val_.at(get_dst(edge)); } bool filter_by_edge(const EdgeType& edge) { return val_.contains(get_dst(edge)); @@ -99,14 +93,12 @@ class HeteroDispatchArg { template class HeteroDispatchArg { public: + using ValueType = V; HeteroDispatchArg(const T& val) : val_(val) { static_assert(is_c10_dict::value, "Should be a c10::dict"); } - template - V value_by_edge(const K& key) { - return val_.at(key); - } + V value_by_edge(const EdgeType& edge) { return val_.at(edge); } bool filter_by_edge(const EdgeType& edge) { return val_.contains(edge); } @@ -145,46 +137,27 @@ bool filter_args_by_edge(const EdgeType& edge, T&& t, Args&&... args) { return t.filter_by_edge(edge) && filter_args_by_edge(edge, args...); } -// Check if a callable is wrapped by std::function -template -struct is_std_function : std::false_type {}; - -template -struct is_std_function> : std::true_type {}; - -template -class HeteroDispatchOp { - public: - using ResultType = typename T::result_type; - HeteroDispatchOp(const EdgeTensorDict& rowptr, - const EdgeTensorDict& col, - T op) - : rowptr_(rowptr), col_(col), op_(op) { - // Check early - static_assert(is_std_function::value, "Must pass a function"); - } +// Specialize +template +auto value_args_by_edge(const EdgeType& edge, Args&&... args) {} - template - c10::Dict operator()(Args&&... args) { - c10::Dict dict; - for (const auto& kv : rowptr_) { - auto edge = kv.key(); - auto rowptr = kv.value(); - auto col = col_.at(edge); - bool pass = filter_args_by_edge(edge, args...); - if (pass) { - ResultType res = op_(rowptr, col, args.value_by_edge(edge)...); - dict.insert(edge, res); - } - } - return dict; - } +// Stop condition of argument filtering +template <> +auto value_args_by_edge(const EdgeType& edge) { + return std::tuple<>(); +} - private: - EdgeTensorDict rowptr_; - EdgeTensorDict col_; - T op_; -}; +// We filter each argument individually by the given edge using a variadic +// template +template +auto value_args_by_edge(const EdgeType& edge, T&& t, Args&&... args) { + using ArgType = std::remove_const_t>; + static_assert(is_hetero_arg::value, + "args should be HeteroDispatchArg"); + return std::tuple_cat( + std::tuple(t.value_by_edge(edge)), + value_args_by_edge(edge, args...)); +} } // namespace utils diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h index 0a1ae3915..7e407f9d9 100644 --- a/pyg_lib/csrc/utils/types.h +++ b/pyg_lib/csrc/utils/types.h @@ -11,7 +11,7 @@ const std::string SPLIT_TOKEN = "__"; using EdgeType = std::string; using NodeType = std::string; -using rel_t = std::string; +using RelationType = std::string; using EdgeTensorDict = c10::Dict; using NodeTensorDict = c10::Dict; @@ -20,7 +20,7 @@ inline NodeType get_src(const EdgeType& e) { return e.substr(0, e.find_first_of(SPLIT_TOKEN)); } -inline rel_t get_rel(const EdgeType& e) { +inline RelationType get_rel(const EdgeType& e) { auto beg = e.find_first_of(SPLIT_TOKEN) + SPLIT_TOKEN.size(); return e.substr(beg, e.find_last_of(SPLIT_TOKEN) - SPLIT_TOKEN.size() + 1 - beg); From c4c446aa3ad6fa30bdd1ed7b1ad77aaf2a336055 Mon Sep 17 00:00:00 2001 From: ZenoTan Date: Sun, 15 May 2022 11:24:53 +0000 Subject: [PATCH 11/13] fix comments --- pyg_lib/csrc/sampler/subgraph.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyg_lib/csrc/sampler/subgraph.cpp b/pyg_lib/csrc/sampler/subgraph.cpp index 7ee90a464..24316d9f7 100644 --- a/pyg_lib/csrc/sampler/subgraph.cpp +++ b/pyg_lib/csrc/sampler/subgraph.cpp @@ -73,16 +73,16 @@ hetero_subgraph(const utils::EdgeTensorDict& rowptr, for (const auto& kv : rowptr) { const auto& edge_type = kv.key(); - bool pass = filter_args_by_edge(edge_type, src_nodes_arg, dst_nodes_arg, - edge_id_arg); + bool pass = src_nodes_arg.filter_by_edge(edge_type) && + dst_nodes_arg.filter_by_edge(edge_type) && + edge_id_arg.filter_by_edge(edge_type); if (pass) { - auto vals = value_args_by_edge(edge_type, src_nodes_arg, dst_nodes_arg, - edge_id_arg); const auto& r = rowptr.at(edge_type); const auto& c = col.at(edge_type); - res.insert(edge_type, - subgraph_bipartite(r, c, std::get<0>(vals), std::get<1>(vals), - std::get<2>(vals))); + res.insert(edge_type, subgraph_bipartite( + r, c, src_nodes_arg.value_by_edge(edge_type), + dst_nodes_arg.value_by_edge(edge_type), + edge_id_arg.value_by_edge(edge_type))); } } From 487eaf6374a5f36615fb3da38e57b7ac2242ecec Mon Sep 17 00:00:00 2001 From: Zeyuan Tan <41138939+ZenoTan@users.noreply.github.com> Date: Sun, 15 May 2022 17:36:57 +0100 Subject: [PATCH 12/13] Update CHANGELOG.md Co-authored-by: Matthias Fey --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72cd7346c..10c8ba8f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] ### Added -- Added hetero subgraph kernel ([#43](https://github.com/pyg-team/pyg-lib/pull/43) +- Added `hetero_subgraph` kernel ([#43](https://github.com/pyg-team/pyg-lib/pull/43) - Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45) - Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45) - Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44) From 48c115d5b0ab163755fbafbb0e69d491a52e12e5 Mon Sep 17 00:00:00 2001 From: Zeyuan Tan <41138939+ZenoTan@users.noreply.github.com> Date: Sun, 15 May 2022 17:37:07 +0100 Subject: [PATCH 13/13] Update test/csrc/utils/test_utils.cpp Co-authored-by: Matthias Fey --- test/csrc/utils/test_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/csrc/utils/test_utils.cpp b/test/csrc/utils/test_utils.cpp index 06ff2a803..cf5bf1fca 100644 --- a/test/csrc/utils/test_utils.cpp +++ b/test/csrc/utils/test_utils.cpp @@ -1,6 +1,6 @@ #include -#include "pyg_lib/csrc/sampler/subgraph.h" +#include "pyg_lib/csrc/utils/types.h" TEST(UtilsTypeTest, BasicAssertions) { pyg::utils::EdgeType edge = "node1__to__node2";