Skip to content

Commit

Permalink
reworked custom CUDA extension option in setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Aug 27, 2024
1 parent ff9beae commit 7b20621
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 21 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ Download directly from PyPI:
```bash
pip install torch-harmonics
```
If you would like to have accelerated CUDA extensions for the discrete-continuous convolutions, please use the '--cuda_ext' flag:
If you would like to enforce the compilation of CUDA extensions for the discrete-continuous convolutions, you can do so by setting the `FORCE_CUDA_EXTENSION` flag. You may also want to set appropriate architectures with the `TORCH_CUDA_ARCH_LIST` flag.
```bash
pip install --global-option --cuda_ext torch-harmonics
export FORCE_CUDA_EXTENSION=1
export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
pip install torch-harmonics
```
:warning: Please note that the custom CUDA extensions currently only support CUDA architectures >= 7.0.

Expand Down
51 changes: 32 additions & 19 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,44 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import sys
import os, sys
import warnings

from setuptools import setup, find_packages
from setuptools.command.install import install

# some code to handle the building of custom modules
FORCE_CUDA_EXTENSION = os.getenv("FORCE_CUDA_EXTENSION", "0") == "1"
BUILD_CPP = BUILD_CUDA = False

# try to import torch
try:
from setuptools import setup, find_packages
except ImportError:
from distutils.core import setup, find_packages
import torch

import torch
from torch.utils import cpp_extension
print(f"setup.py with torch {torch.__version__}")
from torch.utils.cpp_extension import BuildExtension, CppExtension

BUILD_CPP = True
from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

def get_ext_modules(argv):
BUILD_CUDA = FORCE_CUDA_EXTENSION or (torch.cuda.is_available() and (CUDA_HOME is not None))
except (ImportError, TypeError, AssertionError, AttributeError) as e:
warnings.warn(f"building custom extensions skipped: {e}")

compile_cuda_extension = "--cuda_ext" in argv
if "--cuda_ext" in argv:
argv.remove("--cuda_ext")
def get_ext_modules():

print(compile_cuda_extension)
ext_modules = []
cmdclass = {}

ext_modules = [
cpp_extension.CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]),
]
if BUILD_CPP:
print(f"Compiling helper routines for torch-harmonics.")
ext_modules.append(CppExtension("disco_helpers", ["torch_harmonics/csrc/disco/disco_helpers.cpp"]))
cmdclass["build_ext"] = BuildExtension

if torch.cuda.is_available() or compile_cuda_extension:
if BUILD_CUDA:
print(f"Compiling custom CUDA kernels for torch-harmonics.")
ext_modules.append(
cpp_extension.CUDAExtension(
CUDAExtension(
"disco_cuda_extension",
[
"torch_harmonics/csrc/disco/disco_interface.cu",
Expand All @@ -63,15 +75,16 @@ def get_ext_modules(argv):
],
)
)
cmdclass["build_ext"] = BuildExtension

return ext_modules
return ext_modules, cmdclass

if __name__ == "__main__":

ext_modules = get_ext_modules(sys.argv)
ext_modules, cmdclass = get_ext_modules()

setup(
packages=find_packages(),
ext_modules=ext_modules,
cmdclass={"build_ext": cpp_extension.BuildExtension},
cmdclass=cmdclass,
)

0 comments on commit 7b20621

Please sign in to comment.