Skip to content

Commit

Permalink
whisper : add CUDA-specific computation mel spectrograms (#2206)
Browse files Browse the repository at this point in the history
* whisper : use polymorphic class to calculate mel spectrogram

* whisper : add cuda-specific mel spectrogram calculation

* whisper : conditionally compile cufftGetErrorString to avoid warnings

* build : add new files to makefile

* ruby : add new files to conf script

* build : fix typo in makefile

* whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda
  • Loading branch information
iboB authored Jun 4, 2024
1 parent af5833e commit ffef323
Show file tree
Hide file tree
Showing 8 changed files with 497 additions and 99 deletions.
10 changes: 7 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ if (WHISPER_CUDA)
if (WHISPER_STATIC)
if (WIN32)
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
else ()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
endif()
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
endif()

set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
Expand Down Expand Up @@ -679,6 +679,10 @@ add_library(${TARGET}
whisper.cpp
)

if (WHISPER_CUDA)
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
endif()

include_directories (
.
)
Expand Down
9 changes: 6 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ ifdef WHISPER_CUDA

CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
WHISPER_OBJ += ggml-cuda.o
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
Expand All @@ -299,6 +299,9 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif

whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@

ifdef WHISPER_HIPBLAS
ROCM_PATH ?= /opt/rocm
HIPCC ?= $(ROCM_PATH)/bin/hipcc
Expand Down Expand Up @@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h

WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o

whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@

ifndef WHISPER_COREML
Expand Down
1 change: 1 addition & 0 deletions bindings/ruby/ext/extconf.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
require 'mkmf'
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")
Expand Down
Loading

0 comments on commit ffef323

Please sign in to comment.