Pybind support for CPUHashMap (#378)
rusty1s authored Jan 11, 2025
1 parent a27f5bf commit c0fa823
Showing 6 changed files with 95 additions and 55 deletions.
5 changes: 5 additions & 0 deletions .ruff.toml
@@ -16,7 +16,12 @@ select = [
]
ignore = [
"D100", # TODO Don't ignore "Missing docstring in public module"
"D101", # TODO Don't ignore "Missing docstring in public class"
"D102", # TODO Don't ignore "Missing docstring in public method"
"D103", # TODO Don't ignore "Missing docstring in public function"
"D104", # TODO Don't ignore "Missing docstring in public package"
"D105", # Ignore "Missing docstring in magic method"
"D107", # Ignore "Missing docstring in __init__"
"D205", # Ignore "blank line required between summary line and description"
]

3 changes: 3 additions & 0 deletions benchmark/classes/hash_map.py
@@ -0,0 +1,3 @@
from pyg_lib.classes import HashMap

print(HashMap)
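
The benchmark script is a stub for now. A fuller version might time HashMap.get against a pure-Python dict baseline; the following is a hypothetical sketch (the sizes and timing harness are illustrative, not part of this commit):

import time

import torch

from pyg_lib.classes import HashMap

num_keys = 1_000_000  # Illustrative problem size.
key = torch.randperm(num_keys)
query = torch.randint(0, num_keys, (100_000, ))

hash_map = HashMap(key)

start = time.perf_counter()
out = hash_map.get(query)
print(f'HashMap.get: {time.perf_counter() - start:.6f}s')

# Baseline: a pure-Python dict mapping each key to its position.
py_map = {int(k): i for i, k in enumerate(key)}
start = time.perf_counter()
expected = torch.tensor([py_map.get(int(q), -1) for q in query])
print(f'dict.get:    {time.perf_counter() - start:.6f}s')

assert torch.equal(out, expected)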
1 change: 1 addition & 0 deletions pyg_lib/__init__.py
@@ -34,6 +34,7 @@ def load_library(lib_name: str) -> None:
load_library('libpyg')

import pyg_lib.ops # noqa
import pyg_lib.classes # noqa
import pyg_lib.partition # noqa
import pyg_lib.sampler # noqa

18 changes: 18 additions & 0 deletions pyg_lib/classes/__init__.py
@@ -0,0 +1,18 @@
import torch
from torch import Tensor


class HashMap:
def __init__(self, key: Tensor) -> None:
self._map = torch.classes.pyg.CPUHashMap(key)

def get(self, query: Tensor) -> Tensor:
return self._map.get(query)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'


__all__ = [
'HashMap',
]
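
HashMap is a thin wrapper around the torch.classes.pyg.CPUHashMap custom class registered in hash_map.cpp below: the constructor maps each key to its position, and get returns those positions, with -1 for absent keys. A minimal usage sketch (the tensor values are illustrative):

import torch

from pyg_lib.classes import HashMap

key = torch.tensor([10, 20, 30])
hash_map = HashMap(key)

out = hash_map.get(torch.tensor([30, 10, 99]))
assert out.tolist() == [2, 0, -1]  # Positions in `key`; -1 for misses.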
64 changes: 64 additions & 0 deletions pyg_lib/csrc/classes/cpu/hash_map.cpp
@@ -0,0 +1,64 @@
#include "hash_map.h"

#include <torch/library.h>

namespace pyg {
namespace classes {

CPUHashMap::CPUHashMap(const at::Tensor& key) {
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"HashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CPU);
at::checkDim(c, key_arg, 1);
at::checkContiguous(c, key_arg);

// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
key.scalar_type(),
"cpu_hash_map_init",
[&] {
const auto key_data = key.data_ptr<scalar_t>();
for (int64_t i = 0; i < key.numel(); ++i) {
auto [iterator, inserted] = map_.insert({key_data[i], i});
TORCH_CHECK(inserted, "Found duplicated key.");
}
});
// clang-format on
}

at::Tensor CPUHashMap::get(const at::Tensor& query) {
at::TensorArg query_arg{query, "query", 0};
at::CheckedFrom c{"HashMap.get"};
at::checkDeviceType(c, query, at::DeviceType::CPU);
at::checkDim(c, query_arg, 1);
at::checkContiguous(c, query_arg);

const auto options = at::TensorOptions().dtype(at::kLong);
const auto out = at::empty({query.numel()}, options);
auto out_data = out.data_ptr<int64_t>();

// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
query.scalar_type(),
"cpu_hash_map_get",
[&] {
const auto query_data = query.data_ptr<scalar_t>();

for (int64_t i = 0; i < query.numel(); ++i) {
auto it = map_.find(query_data[i]);
out_data[i] = (it != map_.end()) ? it->second : -1;
}
});
// clang-format on

return out;
}

TORCH_LIBRARY(pyg, m) {
m.class_<CPUHashMap>("CPUHashMap")
.def(torch::init<at::Tensor&>())
.def("get", &CPUHashMap::get);
}

} // namespace classes
} // namespace pyg
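
The TORCH_LIBRARY block registers the class under the pyg namespace, so it is also reachable from Python without the HashMap wrapper once libpyg is loaded. A hedged sketch of direct access (assuming a build of this commit):

import torch

import pyg_lib  # noqa: F401  # Loads `libpyg`, which runs the registration.

cpu_map = torch.classes.pyg.CPUHashMap(torch.tensor([4, 8, 15, 16]))
print(cpu_map.get(torch.tensor([15, 23])))  # tensor([ 2, -1])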
59 changes: 4 additions & 55 deletions pyg_lib/csrc/classes/cpu/hash_map.h
@@ -1,6 +1,7 @@
#pragma once

#include <torch/library.h>
#include <ATen/ATen.h>
#include <variant>

namespace pyg {
namespace classes {
@@ -10,64 +10,12 @@ struct CPUHashMap : torch::CustomClassHolder {
using KeyType = std::
variant<bool, uint8_t, int8_t, int16_t, int32_t, int64_t, float, double>;

CPUHashMap(const at::Tensor& key) {
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"HashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CPU);
at::checkDim(c, key_arg, 1);
at::checkContiguous(c, key_arg);

// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
key.scalar_type(),
"cpu_hash_map_init",
[&] {
const auto key_data = key.data_ptr<scalar_t>();
for (int64_t i = 0; i < key.numel(); ++i) {
auto [iterator, inserted] = map_.insert({key_data[i], i});
TORCH_CHECK(inserted, "Found duplicated key.");
}
});
// clang-format on
};

at::Tensor get(const at::Tensor& query) {
at::TensorArg query_arg{query, "query", 0};
at::CheckedFrom c{"HashMap.get"};
at::checkDeviceType(c, query, at::DeviceType::CPU);
at::checkDim(c, query_arg, 1);
at::checkContiguous(c, query_arg);

const auto options = at::TensorOptions().dtype(at::kLong);
const auto out = at::empty({query.numel()}, options);
auto out_data = out.data_ptr<int64_t>();

// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
query.scalar_type(),
"cpu_hash_map_get",
[&] {
const auto query_data = query.data_ptr<scalar_t>();

for (size_t i = 0; i < query.numel(); ++i) {
auto it = map_.find(query_data[i]);
out_data[i] = (it != map_.end()) ? it->second : -1;
}
});
// clang-format on

return out;
}
CPUHashMap(const at::Tensor& key);
at::Tensor get(const at::Tensor& query);

private:
std::unordered_map<KeyType, int64_t> map_;
};

TORCH_LIBRARY(pyg, m) {
m.class_<CPUHashMap>("CPUHashMap")
.def(torch::init<at::Tensor&>())
.def("get", &CPUHashMap::get);
}

} // namespace classes
} // namespace pyg
