Skip to content

Commit

Permalink
Support various data types in CPUHashMap (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 11, 2025
1 parent 78f07dc commit 10b73d9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
53 changes: 37 additions & 16 deletions pyg_lib/csrc/classes/cpu/hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,60 @@
namespace pyg {
namespace classes {

template <typename T>
struct CPUHashMap : torch::CustomClassHolder {
std::unordered_map<T, int64_t> map;
public:
using KeyType = std::
variant<bool, uint8_t, int8_t, int16_t, int32_t, int64_t, float, double>;

CPUHashMap(const at::Tensor& key) {
// TODO Assert 1-dim
const auto key_data = key.data_ptr<T>();
for (int64_t i = 0; i < key.numel(); ++i) {
// TODO Check that key does not yet exist.
map[key_data[i]] = i;
}

// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
key.scalar_type(),
"cpu_hash_map_init",
[&] {
const auto key_data = key.data_ptr<scalar_t>();
for (int64_t i = 0; i < key.numel(); ++i) {
// TODO Check that key does not yet exist.
map_[key_data[i]] = i;
}
});
// clang-format on
};

at::Tensor get(const at::Tensor& query) {
// TODO Assert 1-dim
const auto options = at::TensorOptions().dtype(at::kLong);
auto out = at::empty({query.numel()}, options);

const auto query_data = query.data_ptr<T>();
const auto options = at::TensorOptions().dtype(at::kLong);
const auto out = at::empty({query.numel()}, options);
auto out_data = out.data_ptr<int64_t>();

for (size_t i = 0; i < query.numel(); ++i) {
// TODO Insert -1 if key does not exist.
out_data[i] = map[query_data[i]];
}
// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
query.scalar_type(),
"cpu_hash_map_get",
[&] {
const auto query_data = query.data_ptr<scalar_t>();

for (size_t i = 0; i < query.numel(); ++i) {
// TODO Insert -1 if key does not exist.
out_data[i] = map_[query_data[i]];
}
});
// clang-format on

return out;
}

private:
std::unordered_map<KeyType, int64_t> map_;
};

TORCH_LIBRARY(pyg, m) {
m.class_<CPUHashMap<int64_t>>("CPULongHashMap")
m.class_<CPUHashMap>("CPUHashMap")
.def(torch::init<at::Tensor&>())
.def("get", &CPUHashMap<int64_t>::get);
.def("get", &CPUHashMap::get);
}

} // namespace classes
Expand Down
2 changes: 1 addition & 1 deletion test/csrc/classes/test_hash_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TEST(CPUHashMapTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);
auto key = at::tensor({0, 10, 30, 20}, options);

auto map = pyg::classes::CPUHashMap<int64_t>(key);
auto map = pyg::classes::CPUHashMap(key);

auto query = at::tensor({30, 10, 20}, options);
auto expected = at::tensor({2, 1, 3}, options);
Expand Down

0 comments on commit 10b73d9

Please sign in to comment.