From 97c4c1050edc46bc92b4fe3ca8a9a8aba017b81d Mon Sep 17 00:00:00 2001
From: Rohit Santhanam
Date: Thu, 28 Apr 2022 05:45:49 +0000
Subject: [PATCH] Fix for triangular_solve BEF executable unit test failure on ROCm.

---
 backends/gpu/lib/kernels/blas_kernels.cc | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/backends/gpu/lib/kernels/blas_kernels.cc b/backends/gpu/lib/kernels/blas_kernels.cc
index 2a2a2c3d4be..7078c924487 100644
--- a/backends/gpu/lib/kernels/blas_kernels.cc
+++ b/backends/gpu/lib/kernels/blas_kernels.cc
@@ -222,7 +222,11 @@ static Error BlasTrsmBatch(
   const void** a_array = const_cast<const void**>(b_array + batchCount);
 
   auto side_mode = wrapper::BlasSideMode::FromOpaqueValue(*sideMode);
-  int32_t a_num_elements = side_mode == CUBLAS_SIDE_LEFT ? m * m : n * n;
+  int32_t a_num_elements = 0;
+  if (platform == wrapper::Platform::CUDA)
+    a_num_elements = side_mode == CUBLAS_SIDE_LEFT ? m * m : n * n;
+  else
+    a_num_elements = side_mode == rocblas_side_left ? m * m : n * n;
   ptrdiff_t a_batch_stride_bytes = *data_type_size_bytes * a_num_elements;
   ptrdiff_t b_batch_stride_bytes = *data_type_size_bytes * m * n;
   const char* a_ptr = static_cast<const char*>(A.pointer().raw(platform));