Skip to content

Commit

Permalink
Fix mse gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
ragmani committed Sep 6, 2024
1 parent 93b5d13 commit d233fa7
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions compute/cker/include/cker/train/operation/Loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,17 @@ inline void MSEGrad(const Shape &y_pred_shape, const T *y_pred_data, const Shape
if (y_pred_shape != grad_shape)
throw std::runtime_error("cker::MSEGrad: y_pred_shape != grad_shape");

const int size = grad_shape.FlatSize();
for (int i = 0; i < size; ++i)
// TODO Optimize
const int flat_size = grad_shape.FlatSize();
const int batch_size = grad_shape.Dims(0);
for (int b = 0; b < batch_size; ++b)
{
grad_data[i] = static_cast<T>(-2 * (y_true_data[i] - y_pred_data[i]) / size);
const int size = flat_size / batch_size;
for (int i = 0; i < size; ++i)
{
const int offset = b * size + i;
grad_data[offset] = static_cast<T>(-2 * (y_true_data[offset] - y_pred_data[offset]) / size);
}
}
}

Expand Down

0 comments on commit d233fa7

Please sign in to comment.