Performance Discrepancy Between NVIDIA Triton and Direct Faster-Whisper Inference #8016

YuBeomGon opened this issue on Feb 18, 2025
Description:
I have been testing Faster-Whisper with NVIDIA Triton Inference Server and noticed a significant performance discrepancy compared to running the model directly in Python.

  1. Direct Python Inference:

    Running the model using the following code:

    model = WhisperModel("large-v3", device="cuda", compute_type="int8")

    Processing a single file takes approximately 0.1 seconds.

  2. Inference via NVIDIA Triton (localhost):

    Serving the same model on Triton and sending audio files via HTTP.
    Processing the same file takes approximately 0.2 seconds.

Observations:
Since Triton receives files over HTTP, I suspected that there might be idle periods where the GPU is not fully utilized.
However, monitoring GPU usage with nvidia-smi and gpustat shows a consistent GPU core utilization of ~97%, without noticeable idle gaps.

Question:
Why does inference take twice as long when using NVIDIA Triton compared to direct inference in Python? Is there an inherent overhead in Triton that causes this delay, even though GPU utilization appears to be consistently high?
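
To narrow down where the extra ~0.1 seconds goes, one option is to compare the server-reported inference_time output that model.py already returns against the client-side round-trip time. Below is a minimal sketch of that comparison (it reuses the payload format and port from the code further down, assumes a mono 16 kHz after.wav, and looks outputs up by name rather than position):

import asyncio
import time

import aiohttp
import soundfile as sf

TRITON_URL = "http://localhost:9000/v2/models/faster-whisper/infer"

async def time_one_request(audio_path):
    # Build the same INT16 payload that client_async.py sends
    audio, _ = sf.read(audio_path, dtype="int16")
    payload = {
        "inputs": [{
            "name": "audio",
            "shape": [1, len(audio)],
            "datatype": "INT16",
            "data": audio.tolist(),
        }]
    }

    async with aiohttp.ClientSession() as session:
        start = time.time()
        async with session.post(TRITON_URL, json=payload) as resp:
            result = await resp.json()
        round_trip = time.time() - start

    # config.pbtxt defines two outputs: transcribed_text and inference_time
    outputs = {o["name"]: o["data"] for o in result["outputs"]}
    server_side = outputs["inference_time"][0]  # seconds spent inside model.transcribe() on the server
    print(f"round trip: {round_trip:.3f}s  "
          f"server-side transcribe(): {server_side:.3f}s  "
          f"HTTP/queue/pre-post overhead: {round_trip - server_side:.3f}s")

asyncio.run(time_one_request("after.wav"))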

Setup:
NVIDIA RTX 4080TI
NVIDIA-SMI 550.120, Driver Version: 550.120, CUDA Version: 12.4

Below is the full code.

Direct Python Inference


import asyncio
import time
import os
from concurrent.futures import ThreadPoolExecutor
from faster_whisper import WhisperModel

# Initialize the model
model = WhisperModel("large-v3", device="cuda", compute_type="int8")

# List of audio files to test (100 copies of the same file)
audio_files = ["~/Downloads/audio/after.wav"] * 100

# Transcription function (runs in a worker thread)
def transcribe_audio(audio_file):
    audio_file = os.path.expanduser(audio_file)  # expand "~" so the path resolves
    if not os.path.exists(audio_file):
        return {"file": audio_file, "error": "File not found"}
    
    start_time = time.time()
    try:
        segments, info = model.transcribe(audio_file, without_timestamps=True)
        # transcribe() returns a lazy generator; consume it so the timing covers the full decode
        segments = list(segments)
        language = info.language if info.language else "Unknown"
    except Exception as e:
        return {"file": audio_file, "error": str(e)}
    
    elapsed_time = time.time() - start_time
    return {"file": audio_file, "time": elapsed_time, "language": language}

# Async runner: transcribe all 100 files concurrently
async def transcribe_all():
    loop = asyncio.get_running_loop()
    executor = ThreadPoolExecutor(max_workers=10)  # run up to 10 transcriptions at a time

    start_time = time.time()

    # Submit all 100 audio files concurrently
    tasks = [loop.run_in_executor(executor, transcribe_audio, audio_file) for audio_file in audio_files]
    results = await asyncio.gather(*tasks)

    end_time = time.time()
    total_time = end_time - start_time  # total wall-clock time
    avg_time = sum(r["time"] for r in results if "time" in r) / len([r for r in results if "time" in r])

    # Print results
    print("\n=== 🕒 Whisper Batch Inference Results ===")
    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"Average transcription time per file: {avg_time:.4f} seconds")
    print(f"Total files processed: {len(results)}")
    print("\nSample results:")
    for result in results[:5]:  # print only 5 sample results
        print(result)

# Entry point
if __name__ == "__main__":
    asyncio.run(transcribe_all())

Using Triton


The Triton model repository tree structure is as follows:

faster-whisper
--1
----model.py
--config.pbtxt
--Dockerfile
--client_async.py

---model.py---

from faster_whisper import WhisperModel
import triton_python_backend_utils as pb_utils
import numpy as np
import io
import os
import time
from concurrent.futures import ThreadPoolExecutor

class TritonPythonModel:
    def initialize(self, args):
        device_id = args["model_instance_device_id"] # 0 / 1 / 2 / 3 
        os.environ["CUDA_VISIBLE_DEVICES"] = device_id        
        # Initialize the Whisper model
        # self.model_name = "Systran/faster-whisper-large-v3"
        self.model_name = "large-v3"
        self.compute_type = "int8"
        self.device = "cuda"
        try:
            self.model = WhisperModel(
                self.model_name,
                compute_type=self.compute_type,
                device=self.device,
                # local_files_only=True,
            )
        except Exception as e:
            raise

    def execute(self, requests):
        responses = []
        for request in requests:
            # Read the audio data from the request as a numpy array
            audio = request.inputs()[0].as_numpy()[0]
            print(audio)
            
            # Handle FP32 or INT16 input data
            if audio.dtype == np.float32:
                # audio = (audio * 32768).astype(np.int16)  # FP32 → INT16 conversion
                pass
            elif audio.dtype == np.int16:
                audio = (audio / 32768.0).astype(np.float32)  # normalize INT16 to float32, the dtype faster-whisper expects
                
            # STT inference start time
            start_time = time.time()                 

            # Transcribe the audio with the Whisper model
            segments, _ = self.model.transcribe(
                audio,
                without_timestamps=True,
            )
            
            # STT inference elapsed time
            inference_time = time.time() - start_time
                        
            segments = [segment.text for segment in segments]
            text = " ".join(segments)

            # Build the Triton response (including the inference time)
            responses.append(
                pb_utils.InferenceResponse(
                    [
                        pb_utils.Tensor(
                            "transcribed_text",
                            np.array([text.encode("utf-8")], dtype=np.object_),
                        ),
                        pb_utils.Tensor(
                            "inference_time",
                            np.array([inference_time], dtype=np.float32),
                        ),
                    ],
                ),
            )
        return responses 

    def finalize(self):
        # Release the model
        self.model = None

---config.pbtxt---

name: "faster-whisper"
backend: "python"
max_batch_size: 0

input [
    {
        name: "audio"
        data_type: TYPE_INT16 # TYPE_STRING, TYPE_INT16 TYPE_FP32
        dims:  [1, -1]
    }
]

output [
    {
        name: "transcribed_text"
        data_type: TYPE_STRING
        dims: [ 1 ]
    },
    {
        name: "inference_time"
        data_type: TYPE_FP32
        dims: [ 1 ]
    }
]

instance_group [
   {
     count: 1
     kind: KIND_GPU
     gpus: [0]
   }
]

optimization {
  execution_accelerators {
    gpu_execution_accelerator: []
  }
}

dynamic_batching {
  max_queue_delay_microseconds: 0
}

---client_async.py ---

import time
import asyncio
import os
import aiohttp
import numpy as np
import librosa
import json
import soundfile as sf

# Triton server address
TRITON_URL = "http://localhost:9000/v2/models/faster-whisper/infer"

# List of audio files
audio_files = ["~/after.wav"]

# Load an audio file and convert it to INT16
def load_audio(file_path):
    audio, sr = sf.read(os.path.expanduser(file_path), dtype="int16")  # load as INT16 ("~" expanded so the path resolves)
    
    # Resample to 16 kHz if needed
    if sr != 16000:
        print(f"⚠️ Resampling {file_path}: {sr}Hz → 16000Hz")
        audio = librosa.resample(audio.astype(np.float32) / 32768.0, orig_sr=sr, target_sr=16000)  # normalize to [-1, 1] before resampling
        audio = (audio * 32768).astype(np.int16)  # float → INT16 conversion

    return np.expand_dims(audio.astype(np.int16), axis=0)  # add the batch dimension expected by the Triton input shape

# Async request function
async def send_request(session, audio_data, index):
    start_time = time.time()

    # Triton request payload
    payload = {
        "inputs": [
            {
                "name": "audio",
                "shape": list(audio_data.shape),
                "datatype": "INT16",
                "data": audio_data.flatten().tolist(),
            }
        ]
    }

    # Send the request
    try:
        async with session.post(TRITON_URL, json=payload) as response:
            end_time = time.time()
            elapsed_time = end_time - start_time

            try:
                result = await response.json()
                if response.status == 200:
                    text = result["outputs"][0]["data"][0]
                    return elapsed_time, text
                else:
                    return elapsed_time, f"[Error {response.status}] {await response.text()}"
            except json.JSONDecodeError:
                return elapsed_time, "[Error] Failed to decode JSON response"

    except aiohttp.ClientError as e:
        return 0, f"[Client Error] {str(e)}"

# Async driver: send 100 requests per file
async def process_audio_file(file_path):
    async with aiohttp.ClientSession() as session:
        audio_data = load_audio(file_path)
        tasks = [send_request(session, audio_data, i) for i in range(100)]

        print(f"\n🚀 Sending 100 requests for {file_path}...\n")
        start_time = time.time()
        results = await asyncio.gather(*tasks)
        end_time = time.time()

        # Transcribed texts
        transcriptions = [r[1] for r in results]  # extract the text field from each result

        total_time = end_time - start_time  # total wall-clock time for all 100 requests

        # Print summary
        print(f"\n✅ Completed: {file_path}")
        print(f"⏳ **Total Inference Time for {file_path}: {total_time:.4f} sec**")
        print(f"📌 Sample Transcriptions: {transcriptions[:5]}")  # 샘플 5개 출력

# Run: 100 requests per file
async def main():
    for file_path in audio_files:
        await process_audio_file(file_path)

# Entry point
asyncio.run(main())

----Dockerfile----

FROM nvcr.io/nvidia/tritonserver:24.11-py3

# Install required packages
RUN apt-get update && apt-get install -y \
    git \
    python3-dev \
    build-essential \
    cmake \
    && rm -rf /var/lib/apt/lists/*

# Prepare the source build (copied from the parent build context)
COPY faster-whisper /workspace/triton/faster-whisper
WORKDIR /workspace/faster-whisper

# Install faster-whisper (the source install below is commented out; installing from PyPI instead)
# RUN pip install -e .
RUN pip install faster-whisper

# Switch to the Triton server working directory
WORKDIR /opt/tritonserver

----docker build----
docker build -t tritonserver-with-faster-whisper:v.2.0 .

----docker run----
docker run --gpus=all --ipc=host --rm --net=host -v ~/.cache/huggingface:/root/.cache/huggingface tritonserver-with-faster-whisper:v.2.0 tritonserver --backend-config=python,execution-thread-count=1 --model-repository=/workspace/triton --log-verbose=2 --http-port=9000 --grpc-port=9001 --metrics-port=9002

----test----
Activate the conda environment, then run:
python client_async.py
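
It can also help to pull Triton's per-model statistics after a test run, since they split the server-side latency into queue time and compute time. A minimal sketch against the HTTP statistics endpoint on the ports above (the response is just pretty-printed here rather than parsing specific fields):

import json
import urllib.request

# Triton's statistics extension: per-model request counts and cumulative durations
STATS_URL = "http://localhost:9000/v2/models/faster-whisper/stats"

with urllib.request.urlopen(STATS_URL) as resp:
    stats = json.load(resp)

# The report includes cumulative queue and compute durations (in nanoseconds),
# which helps separate server-side queueing from the actual model compute.
print(json.dumps(stats, indent=2))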
