Skip to content

Commit

Permalink
Add sm_89 and point to nvcuda.dll (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
powderluv authored Dec 26, 2022
1 parent ecfdec1 commit cc6fbdb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions shark/iree_utils/gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def get_iree_gpu_args():
# TODO: Give the user_interface to pass the sm_arch.
sm_arch = get_cuda_sm_cc()
if (
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
sm_arch
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
Expand Down Expand Up @@ -56,7 +57,7 @@ def get_iree_rocm_args():


def get_cuda_sm_cc():
libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
Expand Down

0 comments on commit cc6fbdb

Please sign in to comment.