From 01819462fdc08d3e7cf4f9bedbb1216dd789604a Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Sun, 18 Aug 2024 19:26:04 -0400 Subject: [PATCH] Explicitly set register_float. --- lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m | 2 ++ .../mps/ccv_nnc_scaled_dot_product_attention_mps.m | 1 + 2 files changed, 3 insertions(+) diff --git a/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m b/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m index eee502a51..0d1330c04 100644 --- a/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m +++ b/lib/nnc/cmd/convolution/mps/ccv_nnc_conv_mps.m @@ -260,6 +260,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .B_trans = 1, .D_trans = 0, .fused_bias = (bias ? 1 : 0), + .register_float = 0, .batch_dimension = b_batch_size, .batch_stride_a = a_batch_size > 1 ? H * W * I_dim : 0, @@ -277,6 +278,7 @@ static int _ccv_nnc_conv_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint .B_trans = 0, .D_trans = 1, .fused_bias = (bias ? 1 : 0), + .register_float = 0, .batch_dimension = b_batch_size, .batch_stride_a = w_batch_size > 1 ? O * I_dim : 0, diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m index 44f2d09b5..6d38f0aad 100644 --- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m +++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m @@ -317,6 +317,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c .B_trans = true, .D_trans = false, .fused_bias = (bias ? 1 : 0), + .register_float = 0, .batch_dimension = 1, .batch_stride_a = 0,