Skip to content

Commit

Permalink
Disable Windows support in CUDAHashMap (#398)
Browse files Browse the repository at this point in the history
`cuCollections` doesn't provide native Windows support :(
  • Loading branch information
rusty1s authored Feb 6, 2025
1 parent 53cadc7 commit 08c61e9
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 111 deletions.
6 changes: 0 additions & 6 deletions .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ runs:
sed -i '1s/^/#if defined(__linux__) \&\& defined(__x86_64__)\n__asm__(".symver pow,pow@GLIBC_2.2.5");\n#endif\n/' third_party/METIS/libmetis/metislib.h
shell: bash

- name: Fix cuCollections
if: ${{ (inputs.cuda-version != 'cpu') && (runner.os == 'Windows') }}
run: |
sed -i '37s|#define CUCO_CUDA_MINIMUM_ARCH .*|#define CUCO_CUDA_MINIMUM_ARCH 600|' third_party/cuCollections/include/cuco/detail/__config
shell: bash

- name: Install additional dependencies
run: |
pip install setuptools ninja wheel
Expand Down
32 changes: 1 addition & 31 deletions .github/workflows/building.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
python-version: ['3.9', '3.10', '3.11', '3.12']
# torch-version: [1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 2.5.0]
torch-version: [2.5.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
cuda-version: ['cpu', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.13.0
python-version: '3.12'
Expand All @@ -24,60 +24,30 @@ jobs:
python-version: '3.12'
- torch-version: 1.13.0
python-version: '3.11'
- torch-version: 1.13.0
cuda-version: 'cu113'
- torch-version: 1.13.0
cuda-version: 'cu118'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 1.13.0
cuda-version: 'cu124'
- torch-version: 2.0.0
cuda-version: 'cu113'
- torch-version: 2.0.0
cuda-version: 'cu116'
- torch-version: 2.0.0
cuda-version: 'cu124'
- torch-version: 2.1.0
cuda-version: 'cu113'
- torch-version: 2.1.0
cuda-version: 'cu116'
- torch-version: 2.1.0
cuda-version: 'cu117'
- torch-version: 2.1.0
cuda-version: 'cu124'
- torch-version: 2.2.0
cuda-version: 'cu113'
- torch-version: 2.2.0
cuda-version: 'cu116'
- torch-version: 2.2.0
cuda-version: 'cu117'
- torch-version: 2.2.0
cuda-version: 'cu124'
- torch-version: 2.3.0
cuda-version: 'cu113'
- torch-version: 2.3.0
cuda-version: 'cu116'
- torch-version: 2.3.0
cuda-version: 'cu117'
- torch-version: 2.3.0
cuda-version: 'cu124'
- torch-version: 2.4.0
cuda-version: 'cu113'
- torch-version: 2.4.0
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- torch-version: 2.5.0
cuda-version: 'cu113'
- torch-version: 2.5.0
cuda-version: 'cu116'
- torch-version: 2.5.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
cuda-version: 'cu116'
- os: macos-14
cuda-version: 'cu117'
- os: macos-14
Expand Down
32 changes: 1 addition & 31 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
os: [ubuntu-20.04, macos-14, windows-2019]
python-version: ['3.9', '3.10', '3.11', '3.12']
torch-version: [1.13.0, 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 2.5.0]
cuda-version: ['cpu', 'cu113', 'cu116', 'cu117', 'cu118', 'cu121', 'cu124']
cuda-version: ['cpu', 'cu117', 'cu118', 'cu121', 'cu124']
exclude:
- torch-version: 1.13.0
python-version: '3.12'
Expand All @@ -27,60 +27,30 @@ jobs:
python-version: '3.12'
- torch-version: 1.13.0
python-version: '3.11'
- torch-version: 1.13.0
cuda-version: 'cu113'
- torch-version: 1.13.0
cuda-version: 'cu118'
- torch-version: 1.13.0
cuda-version: 'cu121'
- torch-version: 1.13.0
cuda-version: 'cu124'
- torch-version: 2.0.0
cuda-version: 'cu113'
- torch-version: 2.0.0
cuda-version: 'cu116'
- torch-version: 2.0.0
cuda-version: 'cu124'
- torch-version: 2.1.0
cuda-version: 'cu113'
- torch-version: 2.1.0
cuda-version: 'cu116'
- torch-version: 2.1.0
cuda-version: 'cu117'
- torch-version: 2.1.0
cuda-version: 'cu124'
- torch-version: 2.2.0
cuda-version: 'cu113'
- torch-version: 2.2.0
cuda-version: 'cu116'
- torch-version: 2.2.0
cuda-version: 'cu117'
- torch-version: 2.2.0
cuda-version: 'cu124'
- torch-version: 2.3.0
cuda-version: 'cu113'
- torch-version: 2.3.0
cuda-version: 'cu116'
- torch-version: 2.3.0
cuda-version: 'cu117'
- torch-version: 2.3.0
cuda-version: 'cu124'
- torch-version: 2.4.0
cuda-version: 'cu113'
- torch-version: 2.4.0
cuda-version: 'cu116'
- torch-version: 2.4.0
cuda-version: 'cu117'
- torch-version: 2.5.0
cuda-version: 'cu113'
- torch-version: 2.5.0
cuda-version: 'cu116'
- torch-version: 2.5.0
cuda-version: 'cu117'
- os: macos-14
cuda-version: 'cu113'
- os: macos-14
cuda-version: 'cu116'
- os: macos-14
cuda-version: 'cu117'
- os: macos-14
Expand Down
82 changes: 41 additions & 41 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,47 +34,47 @@ where

The following combinations are supported:

| PyTorch 2.5 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | |||
| **Windows** || | | | |||
| **macOS** || | | | | | |

| PyTorch 2.4 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | |||
| **Windows** || | | | |||
| **macOS** || | | | | | |

| PyTorch 2.3 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | || |
| **Windows** || | | | || |
| **macOS** || | | | | | |

| PyTorch 2.2 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | || |
| **Windows** || | | | || |
| **macOS** || | | | | | |

| PyTorch 2.1 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | | || |
| **Windows** || | | | || |
| **macOS** || | | | | | |

| PyTorch 2.0 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || | | ||| |
| **Windows** || | | || | |
| **macOS** || | | | | | |

| PyTorch 1.13 | `cpu` | `cu113` | `cu116` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|---------|---------|
| **Linux** || || | | | |
| **Windows** || || | | | |
| **macOS** || | | | | | |
| PyTorch 2.5 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** || ||||
| **Windows** || ||||
| **macOS** || | | | |

| PyTorch 2.4 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** || ||||
| **Windows** || ||||
| **macOS** || | | | |

| PyTorch 2.3 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** || ||| |
| **Windows** || ||| |
| **macOS** || | | | |

| PyTorch 2.2 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** || ||| |
| **Windows** || ||| |
| **macOS** || | | | |

| PyTorch 2.1 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** || ||| |
| **Windows** || ||| |
| **macOS** || | | | |

| PyTorch 2.0 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** ||||| |
| **Windows** |||| | |
| **macOS** || | | | |

| PyTorch 1.13 | `cpu` | `cu117` | `cu118` | `cu121` | `cu124` |
|--------------|-------|---------|---------|---------|---------|
| **Linux** ||| | | |
| **Windows** ||| | | |
| **macOS** || | | | |

### Form nightly

Expand Down
25 changes: 23 additions & 2 deletions pyg_lib/csrc/classes/cuda/hash_map.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include <ATen/ATen.h>
#include <torch/library.h>
#include <cuco/static_map.cuh>
#include <limits>

#ifndef _WIN32
#include <cuco/static_map.cuh>
#endif

namespace pyg {
namespace classes {

Expand All @@ -22,6 +25,7 @@ struct HashMapImpl {
virtual at::Tensor keys() = 0;
};

#ifndef _WIN32
template <typename KeyType>
struct CUDAHashMapImpl : HashMapImpl {
public:
Expand Down Expand Up @@ -86,10 +90,12 @@ struct CUDAHashMapImpl : HashMapImpl {
private:
std::unique_ptr<cuco::static_map<KeyType, ValueType>> map_;
};
#endif

struct CUDAHashMap : torch::CustomClassHolder {
public:
CUDAHashMap(const at::Tensor& key, double load_factor = 0.5) {
#ifndef _WIN32
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"CUDAHashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CUDA);
Expand All @@ -99,22 +105,37 @@ struct CUDAHashMap : torch::CustomClassHolder {
DISPATCH_KEY(key.scalar_type(), "cuda_hash_map_init", [&] {
map_ = std::make_unique<CUDAHashMapImpl<scalar_t>>(key, load_factor);
});
#else
TORCH_CHECK(false, "'CUDAHashMap' not supported on Windows");
#endif
}

at::Tensor get(const at::Tensor& query) {
#ifndef _WIN32
at::TensorArg query_arg{query, "query", 0};
at::CheckedFrom c{"CUDAHashMap.get"};
at::checkDeviceType(c, query, at::DeviceType::CUDA);
at::checkDim(c, query_arg, 1);
at::checkContiguous(c, query_arg);

return map_->get(query);
#else
TORCH_CHECK(false, "'CUDAHashMap' not supported on Windows");
#endif
}

at::Tensor keys() { return map_->keys(); }
at::Tensor keys() {
#ifndef _WIN32
return map_->keys();
#else
TORCH_CHECK(false, "'CUDAHashMap' not supported on Windows");
#endif
}

private:
#ifndef _WIN32
std::unique_ptr<HashMapImpl> map_;
#endif
};

} // namespace
Expand Down

0 comments on commit 08c61e9

Please sign in to comment.