-
Notifications
You must be signed in to change notification settings - Fork 32
/
setup.py
103 lines (87 loc) · 3.61 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
from pathlib import Path
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
cur_path = Path(__file__).parent


def get_version(init_file=None):
    """Read the package version string from ``vptq/__init__.py``.

    Scans the file line by line for an assignment of the form
    ``__version__ = "X.Y.Z"`` (single or double quotes, optional trailing
    comment) and returns the quoted value.

    Args:
        init_file: Optional path of the file to scan. Defaults to
            ``vptq/__init__.py`` next to this ``setup.py``.

    Returns:
        The version string, or ``"0.0.1"`` if no ``__version__`` assignment
        is found.
    """
    import re

    path = cur_path / "vptq/__init__.py" if init_file is None else Path(init_file)
    # Anchor on the assignment itself so lines that merely mention
    # ``__version__`` (``__all__`` entries, imports, comments) are ignored,
    # and capture only the quoted value (either quote style, no comment).
    pattern = re.compile(r"""^\s*__version__\s*=\s*['"]([^'"]+)['"]""")
    with open(path, encoding="utf-8") as f:
        for line in f:
            match = pattern.match(line)
            if match:
                return match.group(1)
    return "0.0.1"
def _parse_cuda_arch_list(arch_list):
    """Parse a ``TORCH_CUDA_ARCH_LIST`` value into ``(capability, with_ptx)`` pairs.

    Entries may be separated by semicolons and/or whitespace. A ``+PTX``
    suffix keeps the architecture and additionally requests PTX for it.

    Args:
        arch_list: Raw environment-variable value, e.g. ``"8.0;8.6+PTX"``.

    Returns:
        List such as ``[(80, False), (86, True)]``; empty if no entries.

    Raises:
        ValueError: If an entry's version part is not a number like ``8.6``.
    """
    targets = []
    for entry in arch_list.replace(";", " ").split():
        base, _, suffix = entry.partition("+")
        try:
            capability = round(10 * float(base))
        except ValueError as err:
            raise ValueError(
                f"Unsupported entry {entry!r} in TORCH_CUDA_ARCH_LIST; "
                "expected numeric versions such as '8.0' or '8.6+PTX'"
            ) from err
        targets.append((capability, suffix.upper() == "PTX"))
    return targets


def _resolve_rocm_arch():
    """Return the ROCm architecture list to build for, or ``None``.

    Starts from ``PYTORCH_ROCM_ARCH`` (empty counts as unset) and, when a GPU
    is visible, appends the local device's architecture name (the part of
    ``gcnArchName`` before the first ``:``) if it is not already listed.
    """
    rocm_arch = os.getenv("PYTORCH_ROCM_ARCH") or None
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        device_arch = props.gcnArchName.split(":")[0]
        if rocm_arch is None:
            rocm_arch = device_arch
        elif device_arch not in rocm_arch.replace(" ", ";").split(";"):
            # Compare whole entries: a substring test would treat e.g.
            # "gfx90" as already present in "gfx90a".
            rocm_arch = f"{rocm_arch};{device_arch}"
    return rocm_arch


def build_cuda_extensions():
    """Describe the ``vptq.ops`` native extension for the current torch build.

    CUDA builds target the architectures in ``TORCH_CUDA_ARCH_LIST`` (or a
    default set when it is unset/empty). ROCm builds target
    ``PYTORCH_ROCM_ARCH`` plus the visible GPU, if any, and write the result
    back to ``os.environ["PYTORCH_ROCM_ARCH"]``.

    Environment variables read:
        TORCH_CUDA_ARCH_LIST: e.g. ``"8.0;8.6"`` or ``"8.0 8.6+PTX"``.
        PYTORCH_ROCM_ARCH: ROCm architecture list (may be updated here).
        NVCC_THREADS: value for ``--threads`` (default ``"4"``).

    Returns:
        A one-element list containing the ``CUDAExtension`` for ``vptq.ops``.
    """
    # Pick the toolchain path from the torch build itself rather than from GPU
    # visibility, so a ROCm torch on a GPU-less build host does not receive the
    # nvcc-only flags added in the CUDA branch below.
    is_rocm = torch.version.hip is not None

    arch_flags = []
    if is_rocm:
        rocm_arch = _resolve_rocm_arch()
        if rocm_arch:
            os.environ["PYTORCH_ROCM_ARCH"] = rocm_arch
        compute_capabilities = rocm_arch
    else:
        arch_env = os.getenv("TORCH_CUDA_ARCH_LIST")
        targets = _parse_cuda_arch_list(arch_env) if arch_env else []
        if not targets:
            print("TORCH_CUDA_ARCH_LIST is not set, compiling for all arch")
            default_capabilities = [70, 75, 80, 86, 89, 90]
            targets = [(cap, False) for cap in default_capabilities]
        compute_capabilities = [cap for cap, _ in targets]
        for cap, with_ptx in targets:
            arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
            if with_ptx:
                # "+PTX": also embed PTX for this virtual architecture.
                arch_flags += ["-gencode", f"arch=compute_{cap},code=compute_{cap}"]
    print(" build for compute capabilities: ==============", compute_capabilities)

    # set nvcc threads
    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
    extra_compile_args = {
        "nvcc": [
            "-O3",
            "-std=c++17",
            "-DENABLE_BF16",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            f"--threads={nvcc_threads}",
        ] + arch_flags,
        # NOTE(review): "-lgomp" is a linker flag; as a compile-only argument it
        # is probably ignored — consider moving it to extra_link_args.
        "cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
    }
    if is_rocm:
        extra_compile_args["nvcc"].extend(["-fbracket-depth=1024"])
    else:
        extra_compile_args["nvcc"].extend([
            "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", "-lineinfo"
        ])
    extensions = CUDAExtension(
        "vptq.ops",
        [
            "csrc/ops.cc",
            "csrc/dequant_impl_packed.cu",
        ],
        extra_compile_args=extra_compile_args,
    )
    return [extensions]
def get_requirements(requirements_file=None):
    """Get Python package dependencies from requirements.txt.

    Blank lines and ``#`` comments are dropped. A ``#`` counts as a comment at
    line start or after whitespace, matching pip's requirements-file rule, so
    URL fragments like ``...#egg=name`` are not mistaken for comments. Any
    remaining requirement containing ``https`` (a direct URL reference) is
    excluded, as in the original filter.

    Args:
        requirements_file: Optional path to read. Defaults to
            ``requirements.txt`` next to this ``setup.py``.

    Returns:
        List of requirement strings, stripped of surrounding whitespace.
    """
    import re

    path = cur_path / "requirements.txt" if requirements_file is None else Path(requirements_file)
    comment = re.compile(r"(^|\s)#.*$")
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()
    stripped = (comment.sub("", line).strip() for line in lines)
    return [req for req in stripped if req and "https" not in req]
# Collect the distribution metadata first, then hand it to setuptools.
# Keyword values are listed (and therefore evaluated) in the same order as a
# direct setup(...) call: package discovery, requirements, version, and
# finally the native-extension configuration.
_setup_kwargs = dict(
    name="vptq",
    python_requires=">=3.8",
    packages=find_packages(exclude=[""]),
    install_requires=get_requirements(),
    version=get_version(),
    ext_modules=build_cuda_extensions(),
    cmdclass={"build_ext": BuildExtension},
)
setup(**_setup_kwargs)