
CUDA Illegal memory access for certain input sizes to Whisper #2767

Open

MahmoudAshraf97 opened this issue Feb 9, 2025 · 2 comments
Labels
bug (Something isn't working)

Comments

@MahmoudAshraf97 (Contributor) commented Feb 9, 2025

System Info

  • TRT-LLM 0.17.0.post1
  • H100 PCIe

Who can help?

@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Build the Whisper Large-V3 engine using the official example
  2. Use the following Python inference code:
import torch
import tensorrt_llm.bindings.executor as trtllm

executor = trtllm.Executor(
    encoder_model_path="large-v3/encoder",
    decoder_model_path="large-v3/decoder",
    model_type=trtllm.ModelType.ENCODER_DECODER,
    executor_config=trtllm.ExecutorConfig(
        3,  # max beam width; to reproduce the issue, it must be 3 or larger
        max_batch_size=96,
        kv_cache_config=trtllm.KvCacheConfig(
            free_gpu_memory_fraction=0.5, cross_kv_cache_fraction=0.3
        ),
    ),
)

# the issue is reproducible with these input sizes: [2798, 2838, 2926, 2966]
features = torch.rand((128, 2966)).half()

request = trtllm.Request(
    input_token_ids=[50258],
    max_tokens=200,
    encoder_input_features=features.T.half().contiguous(),  # mel features
    encoder_output_length=features.shape[-1] // 2,
    end_id=50257,
    pad_id=50257,
    sampling_config=trtllm.SamplingConfig(
        beam_width=3,
    ),
)

response = executor.await_responses(executor.enqueue_request(request))[0]
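
To check the result, I would read the response like this; treat it as a sketch, since it simply assumes the standard Response/Result fields of tensorrt_llm.bindings.executor:

# Sketch: inspecting the response; has_error()/error_msg/result/output_token_ids
# are assumed to be the usual executor binding fields.
if response.has_error():
    print("request failed:", response.error_msg)
else:
    # with beam search, output_token_ids holds one token list per beam
    for beam, tokens in enumerate(response.result.output_token_ids):
        print(f"beam {beam}: {len(tokens)} tokens")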

Expected behavior

The request should complete and return a response without errors.

Actual behavior

[TensorRT-LLM][ERROR] Encountered an error in forwardSync function: [TensorRT-LLM][ERROR] CUDA runtime error in ::cudaEventSynchronize(get()): an illegal memory access was encountered (/home/jenkins/agent/workspace/LLM/release-0.17/L0_Test-x86_64/tensorrt_llm/cpp/include/tensorrt_llm/runtime/cudaEvent.h:66)
1       0x7fa8dc29fc4c void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 92
2       0x7fa8dc481d8b tensorrt_llm::runtime::GptDecoderBatched::forwardSync(tensorrt_llm::runtime::decoder_batch::DecoderFinishedEvent const&) + 59
3       0x7fa8dcdc8749 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::decoderSync(tensorrt_llm::batch_manager::ScheduledRequests const&, std::unique_ptr<tensorrt_llm::runtime::decoder_batch::DecoderFinishedEvent const, std::default_delete<tensorrt_llm::runtime::decoder_batch::DecoderFinishedEvent const> > const&) + 617
4       0x7fa8dcdc8ac4 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardSync() + 596
5       0x7fa8dce59d26 tensorrt_llm::executor::Executor::Impl::forwardSync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 54
6       0x7fa8dce659b7 tensorrt_llm::executor::Executor::Impl::executionLoop() + 439
7       0x7fadccd9b5c0 /root/mahmoud/mahmoud/lib/python3.10/site-packages/torch/lib/libtorch.so(+0x145c0) [0x7fadccd9b5c0]
8       0x7fadcb5c91c4 /lib/x86_64-linux-gnu/libc.so.6(+0x891c4) [0x7fadcb5c91c4]
9       0x7fadcb64985c /lib/x86_64-linux-gnu/libc.so.6(+0x10985c) [0x7fadcb64985c]

Additional notes

The issue happens with Large-V3 and Large-V3 Turbo but not with other models. The difference is that the other models use an input of shape (80, seq_len), while the problematic models use (128, seq_len). It also does not happen with a beam size smaller than 3, and it only happens when seq_len is one of these values: [2798, 2838, 2926, 2966].
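
To narrow down which lengths fail, a small sweep over nearby sequence lengths can be run against the same executor; this is only a probing sketch (the range and step are my own arbitrary choices, and it reuses the executor and token ids from the reproduction above):

# Sketch: probe mel sequence lengths around the failing ones to map out which
# inputs trigger the illegal access. Reuses `executor` from the snippet above;
# the range/step below are arbitrary probing values, not from any spec.
import torch
import tensorrt_llm.bindings.executor as trtllm

for seq_len in range(2790, 2971, 4):
    features = torch.rand((128, seq_len)).half()
    request = trtllm.Request(
        input_token_ids=[50258],
        max_tokens=200,
        encoder_input_features=features.T.contiguous(),
        encoder_output_length=seq_len // 2,
        end_id=50257,
        pad_id=50257,
        sampling_config=trtllm.SamplingConfig(beam_width=3),
    )
    resp = executor.await_responses(executor.enqueue_request(request))[0]
    print(f"seq_len={seq_len}:", resp.error_msg if resp.has_error() else "ok")
    if resp.has_error():
        break  # once the illegal access hits, later requests also fail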

@MahmoudAshraf97 added the bug label on Feb 9, 2025
@yuekaizhang commented

@MahmoudAshraf97 I was wondering if this is a bug only in 0.17.0.post1. Also, for the pretrained Whisper model, we should send audio with a seq_len of 3000. Are you using a fine-tuned Whisper?
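
If padding to 3000 frames is the right fix, a minimal sketch would look like the following; padding with zeros is an assumption on my part, and the proper pad value depends on how the log-mel features are normalized:

import torch
import torch.nn.functional as F

TARGET_FRAMES = 3000  # frame count the pretrained Whisper encoder expects

def pad_to_target(features: torch.Tensor) -> torch.Tensor:
    # features: (n_mels, seq_len); pad the time axis on the right up to
    # TARGET_FRAMES. Zero padding is an assumption -- match whatever value
    # your feature extractor uses for silence after normalization.
    pad = TARGET_FRAMES - features.shape[-1]
    return F.pad(features, (0, pad)) if pad > 0 else features

features = pad_to_target(torch.rand((128, 2966)).half())
# then pass features.T.contiguous() and encoder_output_length=TARGET_FRAMES // 2
# to the Request exactly as in the reproduction snippet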

@MahmoudAshraf97 (Contributor, Author) commented

@yuekaizhang I'm using a fine-tuned model, but the issue is reproducible with large-v3 as well.
The issue still persists in 0.18.0.dev2025020400, and once I hit the error above, any request that follows fails with the error below until I restart the process:

[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: [TensorRT-LLM][ERROR] CUDA runtime error in cudaMemcpyAsync(dst, src.data(), src.getSizeInBytes(), cudaMemcpyDefault, mStream->get()): an illegal memory access was encountered (/home/jenkins/agent/workspace/LLM/main/L0_Test-x86_64/tensorrt_llm/cpp/tensorrt_llm/runtime/bufferManager.cpp:156)
1       0x7f7b3cc2d65c void tensorrt_llm::common::check<cudaError>(cudaError, char const*, char const*, int) + 92
2       0x7f7b3d7b74f2 tensorrt_llm::batch_manager::EncoderBuffers::setFromInputs(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::runtime::TllmRuntime const&) + 994
3       0x7f7b3d7b8a4b tensorrt_llm::batch_manager::EncoderBuffers::prepareIO(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::runtime::TllmRuntime const&) + 75
4       0x7f7b3d837ac5 tensorrt_llm::batch_manager::TrtEncoderModel::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 85
5       0x7f7b3d83c872 tensorrt_llm::batch_manager::TrtEncoderModel::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1074
6       0x7f7b3d8dffe8 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 136
7       0x7f7b3d8eb164 tensorrt_llm::executor::Executor::Impl::executionLoop() + 1252
8       0x7f802c0e15c0 /root/mahmoud/mahmoud/lib/python3.12/site-packages/torch/lib/libtorch.so(+0x145c0) [0x7f802c0e15c0]
9       0x7f802f7b01c4 /lib/x86_64-linux-gnu/libc.so.6(+0x891c4) [0x7f802f7b01c4]
10      0x7f802f83085c /lib/x86_64-linux-gnu/libc.so.6(+0x10985c) [0x7f802f83085c]
