Robustify CUDAHashMap implementation + add serialization (#392)
rusty1s authored Feb 1, 2025
1 parent ffc2398 commit ef72b03
Showing 3 changed files with 101 additions and 30 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/cuda/Linux-env.sh
@@ -4,42 +4,42 @@ case ${1} in
   cu124)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-12.4/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6;9.0"
     ;;
   cu121)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-12.1/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6;9.0"
     ;;
   cu118)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-11.8/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6;9.0"
     ;;
   cu117)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-11.7/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
     ;;
   cu116)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-11.6/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
     ;;
   cu115)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-11.5/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
     ;;
   cu113)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-11.3/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
     ;;
   cu102)
     export FORCE_CUDA=1
     export PATH=/usr/local/cuda-10.2/bin:${PATH}
-    export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5"
+    export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5"
     ;;
   *)
     ;;
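Note: with compute capabilities 3.5 and 5.0 dropped from every TORCH_CUDA_ARCH_LIST above, wheels built by this workflow assume sm_60 (Pascal) or newer. A minimal runtime guard against older devices (a sketch, not part of this commit) using only the standard CUDA runtime API:

// check_arch.cu: fail fast on devices below the new sm_60 baseline.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device found\n");
    return 1;
  }
  if (prop.major < 6) {
    std::fprintf(stderr, "sm_%d%d is below the sm_60 baseline\n", prop.major,
                 prop.minor);
    return 1;
  }
  std::printf("sm_%d%d is supported\n", prop.major, prop.minor);
  return 0;
}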
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.15)
+cmake_minimum_required(VERSION 3.18)
 project(pyg)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -43,6 +43,7 @@ if(WITH_CUDA)
   enable_language(CUDA)
   add_definitions(-DWITH_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr -allow-unsupported-compiler")
+  set(CMAKE_CUDA_ARCHITECTURES "60;70;75;80;86;90")

   if (NOT "$ENV{EXTERNAL_CUTLASS_INCLUDE_DIR}" STREQUAL "")
     include_directories($ENV{EXTERNAL_CUTLASS_INCLUDE_DIR})
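The new CMAKE_CUDA_ARCHITECTURES value mirrors the workflow's TORCH_CUDA_ARCH_LIST, so source builds now compile device code for sm_60 through sm_90. As a quick sanity check (again a sketch, not repository code), a kernel can report which fatbin variant the runtime actually selected via __CUDA_ARCH__:

// report_arch.cu: __CUDA_ARCH__ is defined only in device code and names the
// architecture this fatbin variant was compiled for (e.g. 860 on an sm_86 GPU).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void report_arch() {
#ifdef __CUDA_ARCH__
  printf("compiled for sm_%d\n", __CUDA_ARCH__ / 10);
#endif
}

int main() {
  report_arch<<<1, 1>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}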
112 changes: 91 additions & 21 deletions pyg_lib/csrc/classes/cuda/hash_map_impl.cu
@@ -1,69 +1,139 @@
 #include <ATen/ATen.h>
 #include <torch/library.h>
 #include <cuco/static_map.cuh>

-#include "../hash_map_impl.h"
+#include <limits>

 namespace pyg {
 namespace classes {

 namespace {

+#define DISPATCH_CASE_KEY(...)                          \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define DISPATCH_KEY(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_KEY(__VA_ARGS__))
+
+struct HashMapImpl {
+  virtual ~HashMapImpl() = default;
+  virtual at::Tensor get(const at::Tensor& query) = 0;
+  virtual at::Tensor keys() = 0;
+};
+
 template <typename KeyType>
-struct CUDAHashMapImpl {
+struct CUDAHashMapImpl : HashMapImpl {
  public:
   using ValueType = int64_t;

   CUDAHashMapImpl(const at::Tensor& key) {
-    KeyType constexpr empty_key_sentinel = -1;  // TODO
+    KeyType constexpr empty_key_sentinel = std::numeric_limits<KeyType>::min();
     ValueType constexpr empty_value_sentinel = -1;

     map_ = std::make_unique<cuco::static_map<KeyType, ValueType>>(
-        2 * key.numel(),  // loader_factor = 0.5
+        2 * key.numel(),  // load_factor = 0.5
         cuco::empty_key{empty_key_sentinel},
         cuco::empty_value{empty_value_sentinel});

-    const auto key_data = key.data_ptr<KeyType>();
     const auto options =
-        at::TensorOptions().device(key.device()).dtype(at::kLong);
+        key.options().dtype(c10::CppTypeToScalarType<ValueType>::value);
     const auto value = at::arange(key.numel(), options);
+    const auto key_data = key.data_ptr<KeyType>();
     const auto value_data = value.data_ptr<ValueType>();
+    const auto zipped =
+        thrust::make_zip_iterator(thrust::make_tuple(key_data, value_data));

-    map_->insert(key_data, value_data, key.numel());
+    map_->insert(zipped, zipped + key.numel());
   }

-  at::Tensor get(const at::Tensor& query) {
+  at::Tensor get(const at::Tensor& query) override {
     const auto options =
-        at::TensorOptions().device(query.device()).dtype(at::kLong);
+        query.options().dtype(c10::CppTypeToScalarType<ValueType>::value);
     const auto out = at::empty({query.numel()}, options);
     const auto query_data = query.data_ptr<KeyType>();
-    auto out_data = out.data_ptr<int64_t>();
+    const auto out_data = out.data_ptr<ValueType>();

-    map_->find(query_data, out_data, query.numel());
+    map_->find(query_data, query_data + query.numel(), out_data);

     return out;
   }

+  at::Tensor keys() override {
+    // TODO This will not work in multi-GPU scenarios.
+    const auto options = at::TensorOptions()
+                             .device(at::DeviceType::CUDA)
+                             .dtype(c10::CppTypeToScalarType<ValueType>::value);
+    const auto size = static_cast<int64_t>(map_->size());
+
+    at::Tensor key;
+    if (std::is_same<KeyType, int16_t>::value) {
+      key = at::empty({size}, options.dtype(at::kShort));
+    } else if (std::is_same<KeyType, int32_t>::value) {
+      key = at::empty({size}, options.dtype(at::kInt));
+    } else {
+      key = at::empty({size}, options);
+    }
+    const auto value = at::empty({size}, options);
+    const auto key_data = key.data_ptr<KeyType>();
+    const auto value_data = value.data_ptr<ValueType>();
+
+    map_->retrieve_all(key_data, value_data);
+
+    return key.index_select(0, value.argsort());
+  }
+
  private:
   std::unique_ptr<cuco::static_map<KeyType, ValueType>> map_;
 };

-// template struct CUDAHashMapImpl<bool>;
-// template struct CUDAHashMapImpl<uint8_t>;
-// template struct CUDAHashMapImpl<int8_t>;
-// template struct CUDAHashMapImpl<int16_t>;
-// template struct CUDAHashMapImpl<int32_t>;
-// template struct CUDAHashMapImpl<int64_t>;
-// template struct CUDAHashMapImpl<float>;
-// template struct CUDAHashMapImpl<double>;
-
 struct CUDAHashMap : torch::CustomClassHolder {
  public:
-  CUDAHashMap(const at::Tensor& key) {}
+  CUDAHashMap(const at::Tensor& key) {
+    at::TensorArg key_arg{key, "key", 0};
+    at::CheckedFrom c{"CUDAHashMap.init"};
+    at::checkDeviceType(c, key, at::DeviceType::CUDA);
+    at::checkDim(c, key_arg, 1);
+    at::checkContiguous(c, key_arg);
+
+    DISPATCH_KEY(key.scalar_type(), "cuda_hash_map_init", [&] {
+      map_ = std::make_unique<CUDAHashMapImpl<scalar_t>>(key);
+    });
+  }

-  at::Tensor get(const at::Tensor& query) { return query; }
+  at::Tensor get(const at::Tensor& query) {
+    at::TensorArg query_arg{query, "query", 0};
+    at::CheckedFrom c{"CUDAHashMap.get"};
+    at::checkDeviceType(c, query, at::DeviceType::CUDA);
+    at::checkDim(c, query_arg, 1);
+    at::checkContiguous(c, query_arg);
+
+    return map_->get(query);
+  }
+
+  at::Tensor keys() { return map_->keys(); }
+
+ private:
+  std::unique_ptr<HashMapImpl> map_;
 };

 }  // namespace

 TORCH_LIBRARY_FRAGMENT(pyg, m) {
   m.class_<CUDAHashMap>("CUDAHashMap")
       .def(torch::init<at::Tensor&>())
-      .def("get", &CUDAHashMap::get);
+      .def("get", &CUDAHashMap::get)
+      .def("keys", &CUDAHashMap::keys)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<CUDAHashMap>& self) -> at::Tensor {
+            return self->keys();
+          },
+          // __setstate__
+          [](const at::Tensor& state) -> c10::intrusive_ptr<CUDAHashMap> {
+            return c10::make_intrusive<CUDAHashMap>(state);
+          });
 }

 }  // namespace classes
 }  // namespace pyg
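Two robustness fixes above are worth calling out: the empty-key sentinel moves from -1, which is a perfectly legal key value, to std::numeric_limits<KeyType>::min(), a value far less likely to collide with real keys; and insertion and lookup switch to cuco's iterator-based bulk API, zipping keys with their insertion indices. A standalone sketch of the same pattern, assuming a cuco version that exposes the capacity-plus-sentinel constructor and the bulk insert/find calls used in the diff:

// cuco_sketch.cu: bulk insert of (key, index) pairs and bulk lookup,
// mirroring the pattern in CUDAHashMapImpl above.
#include <cuco/static_map.cuh>
#include <thrust/device_vector.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/sequence.h>
#include <limits>
#include <vector>

int main() {
  const std::vector<int64_t> host_keys{42, 7, 19};
  thrust::device_vector<int64_t> key(host_keys.begin(), host_keys.end());
  thrust::device_vector<int64_t> value(key.size());
  thrust::sequence(value.begin(), value.end());  // insertion indices 0, 1, 2

  cuco::static_map<int64_t, int64_t> map{
      2 * key.size(),  // capacity sized for a 0.5 load factor
      cuco::empty_key{std::numeric_limits<int64_t>::min()},
      cuco::empty_value{int64_t{-1}}};

  const auto zipped = thrust::make_zip_iterator(
      thrust::make_tuple(key.begin(), value.begin()));
  map.insert(zipped, zipped + key.size());

  thrust::device_vector<int64_t> query(2);
  query[0] = 7;
  query[1] = 123;  // missing key
  thrust::device_vector<int64_t> out(query.size());
  map.find(query.begin(), query.end(), out.begin());  // out = {1, -1}
  return 0;
}

Since -1 is the empty-value sentinel, misses come back as -1, which callers can test for.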

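On serialization: __getstate__ stores nothing but keys(). That is lossless because the values are always arange(numel()) in insertion order, so the keys tensor alone determines every (key, index) pair, and __setstate__ simply rebuilds the map through the constructor. keys() itself has to undo the arbitrary ordering of retrieve_all(), which it does by argsorting the retrieved values (the insertion indices). A small ATen-only sketch of that idiom, with illustrative tensor values:

// argsort_sketch.cpp: recover insertion order from unordered (key, value)
// pairs, where each value is the key's insertion index.
#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  const auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
  const auto key = at::tensor(std::vector<int64_t>{42, 7, 19}, opts);
  // Pretend retrieve_all() handed the pairs back shuffled:
  const auto perm = at::tensor(std::vector<int64_t>{2, 0, 1}, opts);
  const auto shuffled_key = key.index_select(0, perm);  // {19, 42, 7}
  const auto shuffled_value = perm;                     // insertion indices
  // Sorting by insertion index restores the original key order:
  const auto restored = shuffled_key.index_select(0, shuffled_value.argsort());
  std::cout << restored.equal(key) << std::endl;  // prints 1
  return 0;
}

Because keys() returns insertion order, a reloaded map assigns every key the same index it had before pickling; the in-diff TODO notes that the device handling in keys() does not yet cover multi-GPU setups.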