Commit

[WIP][DRAFT][onert] Optimized BatchMatMul in CPU backend
This commit introduces an improved BatchMatMul (BMM) kernel for the CPU backend by delegating the per-batch matrix multiply to the optimized Gemm routine.

ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz <j.iwaszkiewi@samsung.com>
jiwaszki committed Sep 3, 2024
1 parent fa0a1ca commit b801d40
Showing 1 changed file with 41 additions and 14 deletions.
55 changes: 41 additions & 14 deletions compute/cker/include/cker/operation/reference/BatchMatMul.h
@@ -20,6 +20,7 @@

#include "cker/Types.h"
#include "cker/Shape.h"
#include "cker/operation/optimized/Gemm.h"

namespace nnfw
{
@@ -73,7 +74,7 @@ inline void BatchMatMul(const Shape &lhs_shape, const float *lhs_data, const Sha
// Set params for each matrix multiply.
const int lhs_rows = extended_lhs_shape.Dims(3);
const int rhs_cols = extended_rhs_shape.Dims(4);
const int accum_depth = extended_lhs_shape.Dims(4);

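// Each (b0, b1, b2) iteration below selects one lhs/rhs sub-matrix via the precomputed
// batch extents and writes one contiguous lhs_rows x rhs_cols block of the output.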
for (int b0 = 0; b0 < batch_dim0; ++b0)
{
@@ -89,19 +90,45 @@
const float *rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
float *out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
lhs_rows * rhs_cols;
for (int j = 0; j < rhs_cols; ++j)
{
for (int i = 0; i < lhs_rows; ++i)
{
float total = 0.f;
for (int k = 0; k < accum_depth; ++k)
{
total += lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
}
int idx = lhs_rows * j + i;
out_ptr[idx] = total;
}
}

MatrixParams<float> rhs_params;
// rhs_ptr2[j * accum_depth + k] in the reference loop: column-major [accum_depth x rhs_cols].
// TODO Check whether the order needs to depend on adj_x / adj_y.
rhs_params.order = Order::kColMajor;
rhs_params.rows = accum_depth;
rhs_params.cols = rhs_cols;
// TODO Base the cache policy on whether the operand is constant.
rhs_params.cache_policy = nnfw::cker::optimized::DefaultCachePolicy(false);

MatrixParams<float> lhs_params;
// lhs_ptr2[accum_depth * i + k] in the reference loop: row-major [lhs_rows x accum_depth].
// TODO Check whether the order needs to depend on adj_x / adj_y.
lhs_params.order = Order::kRowMajor;
lhs_params.rows = lhs_rows;
lhs_params.cols = accum_depth;
// TODO Base the cache policy on whether the operand is constant.
lhs_params.cache_policy = nnfw::cker::optimized::DefaultCachePolicy(false);

MatrixParams<float> dst_params;
// out_ptr[lhs_rows * j + i] in the reference loop: column-major [lhs_rows x rhs_cols].
dst_params.order = Order::kColMajor;
dst_params.rows = lhs_rows;
dst_params.cols = rhs_cols;

GemmParams<float, float> gemm_params;

nnfw::cker::optimized::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2, dst_params, out_ptr,
gemm_params);

}
}
}
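For reference, below is a small standalone sketch (not part of the commit, and not using the cker headers) that checks the layout mapping assumed by the MatrixParams above: the reference triple loop reads lhs as row-major [lhs_rows x accum_depth], reads rhs as column-major [accum_depth x rhs_cols], and writes the output as column-major [lhs_rows x rhs_cols]. The file and helper names (gemm_layout_check.cc, ReferenceMatMul, GemmWithDeclaredLayouts, the index helpers) are made up for illustration.

// gemm_layout_check.cc -- standalone illustration only, not part of the cker sources.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Element index in a row-major matrix of size [rows x cols].
static int RowMajorIndex(int row, int col, int /*rows*/, int cols) { return row * cols + col; }
// Element index in a column-major matrix of size [rows x cols].
static int ColMajorIndex(int row, int col, int rows, int /*cols*/) { return col * rows + row; }

// The triple loop removed by this commit, with M = lhs_rows, N = rhs_cols, K = accum_depth.
static void ReferenceMatMul(const float *lhs, const float *rhs, float *dst, int M, int N, int K)
{
  for (int j = 0; j < N; ++j)
  {
    for (int i = 0; i < M; ++i)
    {
      float total = 0.f;
      for (int k = 0; k < K; ++k)
      {
        total += lhs[K * i + k] * rhs[j * K + k];
      }
      dst[M * j + i] = total;
    }
  }
}

// The same product expressed through the layouts the MatrixParams declare:
// lhs row-major [M x K], rhs column-major [K x N], dst column-major [M x N].
static void GemmWithDeclaredLayouts(const float *lhs, const float *rhs, float *dst, int M, int N,
                                    int K)
{
  for (int i = 0; i < M; ++i)
  {
    for (int j = 0; j < N; ++j)
    {
      float acc = 0.f;
      for (int k = 0; k < K; ++k)
      {
        acc += lhs[RowMajorIndex(i, k, M, K)] * rhs[ColMajorIndex(k, j, K, N)];
      }
      dst[ColMajorIndex(i, j, M, N)] = acc;
    }
  }
}

int main()
{
  const int M = 3, N = 4, K = 5;
  std::vector<float> lhs(M * K), rhs(K * N), ref(M * N), out(M * N);
  for (int i = 0; i < M * K; ++i) lhs[i] = 0.1f * static_cast<float>(i + 1);
  for (int i = 0; i < K * N; ++i) rhs[i] = 0.2f * static_cast<float>(i + 1);

  ReferenceMatMul(lhs.data(), rhs.data(), ref.data(), M, N, K);
  GemmWithDeclaredLayouts(lhs.data(), rhs.data(), out.data(), M, N, K);

  for (int i = 0; i < M * N; ++i)
    assert(std::fabs(ref[i] - out[i]) < 1e-5f);
  std::printf("layout mapping matches the reference loop\n");
  return 0;
}

If the optimized Gemm routine follows the same Order convention, the parameters chosen above should reproduce the results of the removed loop.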
