diff --git a/onert-micro/onert-micro/include/pal/common/PALSoftmaxInputGrad.h b/onert-micro/onert-micro/include/pal/common/PALSoftmaxInputGrad.h index 172a19c770d..36a0b86f10a 100644 --- a/onert-micro/onert-micro/include/pal/common/PALSoftmaxInputGrad.h +++ b/onert-micro/onert-micro/include/pal/common/PALSoftmaxInputGrad.h @@ -37,20 +37,25 @@ void inline SoftmaxInputGrad(const float *dloss_doutput_data, { assert(dloss_doutput_shape.dimensionsCount() == 2); assert(dloss_doutput_shape.dims(0) == 1); - const uint32_t output_dim = dloss_doutput_shape.dims(dloss_doutput_shape.dimensionsCount() - 1); - for (int i = 0; i < output_dim; ++i) + const uint32_t width = dloss_doutput_shape.dims(dloss_doutput_shape.dimensionsCount() - 1); + for (int w1 = 0; w1 < width; ++w1) { - for (int j = 0; j < output_dim; ++j) + float sum = 0.0f; + for (int w2 = 0; w2 < width; ++w2) { - jacobian_row_data[j] = -calculated_data[i] * calculated_data[j]; + float val; + if (w1 == w2) + { + val = calculated_data[w2] * (1.f - calculated_data[w2]); + } + else + { + val = -calculated_data[w2] * calculated_data[w1]; + } + val *= dloss_doutput_data[w2]; + sum += val; } - jacobian_row_data[i] += calculated_data[i]; - float total = 0.f; - for (int j = 0; j < output_dim; ++j) - { - total += jacobian_row_data[j] * dloss_doutput_data[j]; - } - dloss_dinput_data[i] = total; + dloss_dinput_data[w1] = sum; } }