Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable Windows support in CUDAHashMap #398

Merged
merged 2 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ runs:
sed -i '1s/^/#if defined(__linux__) \&\& defined(__x86_64__)\n__asm__(".symver pow,pow@GLIBC_2.2.5");\n#endif\n/' third_party/METIS/libmetis/metislib.h
shell: bash

- name: Fix cuCollections
if: ${{ (inputs.cuda-version != 'cpu') && (runner.os == 'Windows') }}
run: |
sed -i '37s|#define CUCO_CUDA_MINIMUM_ARCH .*|#define CUCO_CUDA_MINIMUM_ARCH 600|' third_party/cuCollections/include/cuco/detail/__config
shell: bash

- name: Install additional dependencies
run: |
pip install setuptools ninja wheel
Expand Down
32 changes: 1 addition & 31 deletions .github/workflows/building.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
python-version: ['3.9', '3.10', '3.11', '3.12']
# torch-version: [1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 2.5.0]
torch-version: [2.5.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
cuda-version: ['cpu', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.13.0
python-version: '3.12'
Expand All @@ -24,60 +24,30 @@ jobs:
python-version: '3.12'
- torch-version: 1.13.0
python-version: '3.11'
- torch-version: 1.13.0
cuda-version: 'cu113'
- torch-version: 1.13.0
cuda-version: 'cu118'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 1.13.0
cuda-version: 'cu124'
- torch-version: 2.0.0
cuda-version: 'cu113'
- torch-version: 2.0.0
cuda-version: 'cu116'
- torch-version: 2.0.0
cuda-version: 'cu124'
- torch-version: 2.1.0
cuda-version: 'cu113'
- torch-version: 2.1.0
cuda-version: 'cu116'
- torch-version: 2.1.0
cuda-version: 'cu117'
- torch-version: 2.1.0
cuda-version: 'cu124'
- torch-version: 2.2.0
cuda-version: 'cu113'
- torch-version: 2.2.0
cuda-version: 'cu116'
- torch-version: 2.2.0
cuda-version: 'cu117'
- torch-version: 2.2.0
cuda-version: 'cu124'
- torch-version: 2.3.0
cuda-version: 'cu113'
- torch-version: 2.3.0
cuda-version: 'cu116'
- torch-version: 2.3.0
cuda-version: 'cu117'
- torch-version: 2.3.0
cuda-version: 'cu124'
- torch-version: 2.4.0
cuda-version: 'cu113'
- torch-version: 2.4.0
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- torch-version: 2.5.0
cuda-version: 'cu113'
- torch-version: 2.5.0
cuda-version: 'cu116'
- torch-version: 2.5.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
cuda-version: 'cu116'
- os: macos-14
cuda-version: 'cu117'
- os: macos-14
Expand Down
32 changes: 1 addition & 31 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
os: [ubuntu-20.04, macos-14, windows-2019]
python-version: ['3.9', '3.10', '3.11', '3.12']
torch-version: [1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 2.5.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
cuda-version: ['cpu', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.13.0
python-version: '3.12'
Expand All @@ -27,60 +27,30 @@ jobs:
python-version: '3.12'
- torch-version: 1.13.0
python-version: '3.11'
- torch-version: 1.13.0
cuda-version: 'cu113'
- torch-version: 1.13.0
cuda-version: 'cu118'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 1.13.0
cuda-version: 'cu124'
- torch-version: 2.0.0
cuda-version: 'cu113'
- torch-version: 2.0.0
cuda-version: 'cu116'
- torch-version: 2.0.0
cuda-version: 'cu124'
- torch-version: 2.1.0
cuda-version: 'cu113'
- torch-version: 2.1.0
cuda-version: 'cu116'
- torch-version: 2.1.0
cuda-version: 'cu117'
- torch-version: 2.1.0
cuda-version: 'cu124'
- torch-version: 2.2.0
cuda-version: 'cu113'
- torch-version: 2.2.0
cuda-version: 'cu116'
- torch-version: 2.2.0
cuda-version: 'cu117'
- torch-version: 2.2.0
cuda-version: 'cu124'
- torch-version: 2.3.0
cuda-version: 'cu113'
- torch-version: 2.3.0
cuda-version: 'cu116'
- torch-version: 2.3.0
cuda-version: 'cu117'
- torch-version: 2.3.0
cuda-version: 'cu124'
- torch-version: 2.4.0
cuda-version: 'cu113'
- torch-version: 2.4.0
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- torch-version: 2.5.0
cuda-version: 'cu113'
- torch-version: 2.5.0
cuda-version: 'cu116'
- torch-version: 2.5.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
cuda-version: 'cu116'
- os: macos-14
cuda-version: 'cu117'
- os: macos-14
Expand Down
82 changes: 41 additions & 41 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,47 +34,47 @@ where

The following combinations are supported:

| PyTorch 2.5 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | | | ✅ | ✅ | ✅ |
| **Windows** | ✅ | | | | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | | | | |

| PyTorch 2.4 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | | | ✅ | ✅ | ✅ |
| **Windows** | ✅ | | | | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | | | | |

| PyTorch 2.3 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | | | ✅ | ✅ | |
| **Windows** | ✅ | | | | ✅ | ✅ | |
| **macOS** | ✅ | | | | | | |

| PyTorch 2.2 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | | | ✅ | ✅ | |
| **Windows** | ✅ | | | | ✅ | ✅ | |
| **macOS** | ✅ | | | | | | |

| PyTorch 2.1 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | | | ✅ | ✅ | |
| **Windows** | ✅ | | | | ✅ | ✅ | |
| **macOS** | ✅ | | | | | | |

| PyTorch 2.0 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | | ✅ | ✅ | ✅ | |
| **Windows** | ✅ | | | ✅ | ✅ | | |
| **macOS** | ✅ | | | | | | |

| PyTorch 1.13 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** | ✅ | | ✅ | ✅ | | | |
| **Windows** | ✅ | | ✅ | ✅ | | | |
| **macOS** | ✅ | | | | | | |
| PyTorch 2.5 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | | ✅ | ✅ | ✅ |
| **Windows** | ✅ | | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | | |

| PyTorch 2.4 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | | ✅ | ✅ | ✅ |
| **Windows** | ✅ | | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | | |

| PyTorch 2.3 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | | ✅ | ✅ | |
| **Windows** | ✅ | | ✅ | ✅ | |
| **macOS** | ✅ | | | | |

| PyTorch 2.2 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | | ✅ | ✅ | |
| **Windows** | ✅ | | ✅ | ✅ | |
| **macOS** | ✅ | | | | |

| PyTorch 2.1 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | | ✅ | ✅ | |
| **Windows** | ✅ | | ✅ | ✅ | |
| **macOS** | ✅ | | | | |

| PyTorch 2.0 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ | ✅ | |
| **Windows** | ✅ | ✅ | ✅ | | |
| **macOS** | ✅ | | | | |

| PyTorch 1.13 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** | ✅ | ✅ | | | |
| **Windows** | ✅ | ✅ | | | |
| **macOS** | ✅ | | | | |

### Form nightly

Expand Down
25 changes: 23 additions & 2 deletions pyg_lib/csrc/classes/cuda/hash_map.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include <ATen/ATen.h>
#include <torch/library.h>
#include <cuco/static_map.cuh>
#include <limits>

#ifndef _WIN32
#include <cuco/static_map.cuh>
#endif

namespace pyg {
namespace classes {

Expand All @@ -22,6 +25,7 @@ struct HashMapImpl {
virtual at::Tensor keys() = 0;
};

#ifndef _WIN32
template <typename KeyType>
struct CUDAHashMapImpl : HashMapImpl {
public:
Expand Down Expand Up @@ -86,10 +90,12 @@ struct CUDAHashMapImpl : HashMapImpl {
private:
std::unique_ptr<cuco::static_map<KeyType, ValueType>> map_;
};
#endif

struct CUDAHashMap : torch::CustomClassHolder {
public:
CUDAHashMap(const at::Tensor& key, double load_factor = 0.5) {
#ifndef _WIN32
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"CUDAHashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CUDA);
Expand All @@ -99,22 +105,37 @@ struct CUDAHashMap : torch::CustomClassHolder {
DISPATCH_KEY(key.scalar_type(), "cuda_hash_map_init", [&] {
map_ = std::make_unique<CUDAHashMapImpl<scalar_t>>(key, load_factor);
});
#else
TORCH_CHECK(false, "'CUDAHashMap' not supported on Windows");
#endif
}

at::Tensor get(const at::Tensor& query) {
#ifndef _WIN32
at::TensorArg query_arg{query, "query", 0};
at::CheckedFrom c{"CUDAHashMap.get"};
at::checkDeviceType(c, query, at::DeviceType::CUDA);
at::checkDim(c, query_arg, 1);
at::checkContiguous(c, query_arg);

return map_->get(query);
#else
TORCH_CHECK(false, "'CUDAHashMap' not supported on Windows");
#endif
}

at::Tensor keys() { return map_->keys(); }
at::Tensor keys() {
#ifndef _WIN32
return map_->keys();
#else
TORCH_CHECK(false, "'CUDAHashMap' not supported on Windows");
#endif
}

private:
#ifndef _WIN32
std::unique_ptr<HashMapImpl> map_;
#endif
};

} // namespace
Expand Down
Loading