From 6e0f757bcdaafae3813c33c4ad99402c2cc5322c Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Mon, 6 Mar 2023 14:08:07 +0100
Subject: [PATCH] Use FP32 compute type for FP16 convolutions (#1115)

---
 src/ops/conv1d_gpu.cu | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/ops/conv1d_gpu.cu b/src/ops/conv1d_gpu.cu
index bc3ba7598..e04cc6eb9 100644
--- a/src/ops/conv1d_gpu.cu
+++ b/src/ops/conv1d_gpu.cu
@@ -50,7 +50,11 @@ namespace ctranslate2 {
                                                   /*stride_h=*/1, /*stride_w=*/_stride,
                                                   /*dilation_h=*/1, /*dilation_w=*/_dilation,
                                                   CUDNN_CROSS_CORRELATION,
-                                                  data_type));
+                                                  CUDNN_DATA_FLOAT));
+
+      CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));
+      if (data_type == CUDNN_DATA_HALF)
+        CUDNN_CHECK(cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
 
       cudnnHandle_t handle = cuda::get_cudnn_handle();