diff --git a/lib/nnc/ccv_cnnp_model.c b/lib/nnc/ccv_cnnp_model.c index 9160aafb6..ee506e039 100644 --- a/lib/nnc/ccv_cnnp_model.c +++ b/lib/nnc/ccv_cnnp_model.c @@ -1108,10 +1108,10 @@ static void _ccv_cnnp_model_gradient_init(ccv_cnnp_model_t* const model, const i const ccv_nnc_graph_exec_symbol_t* destinations = ccv_nnc_symbolic_graph_destinations(model->graph); const int destination_count = ccv_nnc_symbolic_graph_destination_size(model->graph); int flag = 0; - const int outgrad_destination_start = destination_count - i; - assert(outgrad_destination_start >= 0); + const int outgrad_destination_start = ccv_max(0, destination_count - i); for (j = i - 1; !flag && j >= 0; j--) - flag = (destinations[j + outgrad_destination_start].d == outgrad.d); + if (j + outgrad_destination_start < destination_count) + flag = (destinations[j + outgrad_destination_start].d == outgrad.d); if (!flag) // Only if we cannot find it, we add it. ccv_nnc_symbolic_graph_add_destination(model->graph, outgrad); } diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m index 2735f1d74..120fa1023 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m @@ -704,7 +704,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context(); const int is_mfa_supported = - ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && is_same_batch && !bias && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION); + ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && is_same_batch && !bias && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM); size_t a_data_size = 0; if (a && dw && CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX) diff --git a/lib/nnc/mfa/ccv_nnc_mfa_adam.cpp b/lib/nnc/mfa/ccv_nnc_mfa_adam.cpp index 425057188..68af8fb3e 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_adam.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_adam.cpp @@ -4,6 +4,8 @@ using namespace ccv::nnc; #include +#include +#include // MARK: - C @@ -137,6 +139,12 @@ std::size_t std::hash::operator()(const mfa::adam::hash& hash) return seed; } +static std::string high_precision_to_string(float value) { + std::ostringstream oss; + oss << std::setprecision(std::numeric_limits::max_digits10) << value; + return oss.str(); +} + mfa::adam::pipeline::pipeline(mfa::context* context, mfa::adam::hash hash) { // FlashNorm not supported for group adam yet. CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf)) @@ -309,24 +317,24 @@ kernel void adam( if (hash.adamw) { defines += "constant float rate_decay = "; - defines += std::to_string(hash.rate * hash.decay) + ";"; + defines += high_precision_to_string(hash.rate * hash.decay) + ";"; defines += "\n"; } else { defines += "constant float decay = "; - defines += std::to_string(hash.decay) + ";"; + defines += high_precision_to_string(hash.decay) + ";"; defines += "\n"; } defines += "constant float scale = "; - defines += std::to_string(hash.scale) + ";"; + defines += high_precision_to_string(hash.scale) + ";"; defines += "\n"; defines += "constant float beta1 = "; - defines += std::to_string(hash.beta1) + ";"; + defines += high_precision_to_string(hash.beta1) + ";"; defines += "\n"; defines += "constant float beta2 = "; - defines += std::to_string(hash.beta2) + ";"; + defines += high_precision_to_string(hash.beta2) + ";"; defines += "\n"; defines += "constant float epsilon = "; - defines += std::to_string(hash.epsilon) + ";"; + defines += high_precision_to_string(hash.epsilon) + ";"; defines += "\n"; this->group_size = MTL::Size(threadgroup_size, 1, 1);