Skip to content

Commit

Permalink
[cker] Fix computing MSE Gradient
Browse files Browse the repository at this point in the history
This commit fixes MSE gradient to compute MSE Gradient for each batch.

ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
  • Loading branch information
ragmani committed Sep 6, 2024
1 parent 67df4c5 commit ded8b14
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions compute/cker/include/cker/train/operation/Loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,16 @@ inline void MSEGrad(const Shape &y_pred_shape, const T *y_pred_data, const Shape
if (y_pred_shape != grad_shape)
throw std::runtime_error("cker::MSEGrad: y_pred_shape != grad_shape");

const int size = grad_shape.FlatSize();
for (int i = 0; i < size; ++i)
const int batch_size = grad_shape.Dims(0);
const int size = FlatSizeSkipDim(grad_shape, 0);
for (int b = 0; b < batch_size; ++b)
{
grad_data[i] = static_cast<T>(-2 * (y_true_data[i] - y_pred_data[i]) / size);
for (int i = 0; i < size; ++i)
{
const int offset = b * size + i;
assert(offset >= 0);
grad_data[offset] = static_cast<T>(-2 * (y_true_data[offset] - y_pred_data[offset]) / size);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion compute/cker/src/train/Loss.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ TEST(CKer_Operation, LossMSEGrad)
std::vector<float> y_pred = {27.2, 31.8, 51.9, 10.2, 34.2, 12.4};
std::vector<float> y_true = {31.3, 40.3, 29.7, 12.9, 25.8, 11.9};
std::vector<float> deriv_y_pred(6);
std::vector<float> expected = {-1.3666667, -2.8333333, 7.4, -0.9, 2.8, 0.1666667};
std::vector<float> expected = {-2.7333324, -5.6666665, 14.8, -1.8, 5.6, 0.33333334};

nnfw::cker::train::MSEGrad(nnfw::cker::Shape{2, 3}, y_pred.data(), nnfw::cker::Shape{2, 3},
y_true.data(), nnfw::cker::Shape{2, 3}, deriv_y_pred.data());
Expand Down

0 comments on commit ded8b14

Please sign in to comment.