Skip to content

Commit

Permalink
Fix precision issue that cause nan for adam shader.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Mar 22, 2024
1 parent b13d5c2 commit 8503fea
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
6 changes: 3 additions & 3 deletions lib/nnc/ccv_cnnp_model.c
Original file line number Diff line number Diff line change
Expand Up @@ -1108,10 +1108,10 @@ static void _ccv_cnnp_model_gradient_init(ccv_cnnp_model_t* const model, const i
const ccv_nnc_graph_exec_symbol_t* destinations = ccv_nnc_symbolic_graph_destinations(model->graph);
const int destination_count = ccv_nnc_symbolic_graph_destination_size(model->graph);
int flag = 0;
const int outgrad_destination_start = destination_count - i;
assert(outgrad_destination_start >= 0);
const int outgrad_destination_start = ccv_max(0, destination_count - i);
for (j = i - 1; !flag && j >= 0; j--)
flag = (destinations[j + outgrad_destination_start].d == outgrad.d);
if (j + outgrad_destination_start < destination_count)
flag = (destinations[j + outgrad_destination_start].d == outgrad.d);
if (!flag) // Only if we cannot find it, we add it.
ccv_nnc_symbolic_graph_add_destination(model->graph, outgrad);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ static int _ccv_nnc_gemm_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint

ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && is_same_batch && !bias && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION);
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && is_same_batch && !bias && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && !(ccv_nnc_flags() & CCV_NNC_DISABLE_MFA_GEMM);

size_t a_data_size = 0;
if (a && dw && CCV_GET_DATA_TYPE(a->info.datatype) == CCV_QX)
Expand Down
20 changes: 14 additions & 6 deletions lib/nnc/mfa/ccv_nnc_mfa_adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using namespace ccv::nnc;

#include <string>
#include <sstream>
#include <iomanip>

// MARK: - C

Expand Down Expand Up @@ -137,6 +139,12 @@ std::size_t std::hash<mfa::adam::hash>::operator()(const mfa::adam::hash& hash)
return seed;
}

static std::string high_precision_to_string(float value) {
std::ostringstream oss;
oss << std::setprecision(std::numeric_limits<float>::max_digits10) << value;
return oss.str();
}

mfa::adam::pipeline::pipeline(mfa::context* context, mfa::adam::hash hash) {
// FlashNorm not supported for group adam yet.
CCV_NNC_MFA_PRECONDITION((hash.data_type == MTL::DataTypeFloat) || (hash.data_type == MTL::DataTypeHalf))
Expand Down Expand Up @@ -309,24 +317,24 @@ kernel void adam(
if (hash.adamw)
{
defines += "constant float rate_decay = ";
defines += std::to_string(hash.rate * hash.decay) + ";";
defines += high_precision_to_string(hash.rate * hash.decay) + ";";
defines += "\n";
} else {
defines += "constant float decay = ";
defines += std::to_string(hash.decay) + ";";
defines += high_precision_to_string(hash.decay) + ";";
defines += "\n";
}
defines += "constant float scale = ";
defines += std::to_string(hash.scale) + ";";
defines += high_precision_to_string(hash.scale) + ";";
defines += "\n";
defines += "constant float beta1 = ";
defines += std::to_string(hash.beta1) + ";";
defines += high_precision_to_string(hash.beta1) + ";";
defines += "\n";
defines += "constant float beta2 = ";
defines += std::to_string(hash.beta2) + ";";
defines += high_precision_to_string(hash.beta2) + ";";
defines += "\n";
defines += "constant float epsilon = ";
defines += std::to_string(hash.epsilon) + ";";
defines += high_precision_to_string(hash.epsilon) + ";";
defines += "\n";
this->group_size = MTL::Size(threadgroup_size, 1, 1);

Expand Down

0 comments on commit 8503fea

Please sign in to comment.