Skip to content

Commit

Permalink
kernel test for GRU
Browse files Browse the repository at this point in the history
  • Loading branch information
chunseoklee committed Oct 23, 2024
1 parent 7638ab8 commit 034e72e
Showing 1 changed file with 68 additions and 1 deletion.
69 changes: 68 additions & 1 deletion compiler/luci-interpreter/src/kernels/GRU.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,74 @@ class GRUTest : public ::testing::Test
std::unique_ptr<IMemoryManager> _memory_manager;
};

TEST_F(GRUTest, floatTest) { SUCCEED(); }
TEST_F(GRUTest, floatTest) {
Shape input_shape{2, 1, 2};
std::vector<float> input_data {0.98045033, 0.39546537, 0.5209594, 0.72873044};

Shape ref_output_shape{1, 1, 2};
std::vector<float> ref_output_data{0.22777566, -0.1976251};

Shape hidden_hidden_shape{6, 2};
std::vector<float> hidden_hidden_data {
0.8073279857635498,
-0.5218740105628967,
0.1166749969124794,
0.33110499382019043,
0.2770330011844635,
0.23767800629138947,
0.1293960064649582,
0.17175200581550598,
-0.15584999322891235,
0.8137810230255127,
-0.2667199969291687,
-0.23028500378131866 };
Shape hidden_input_shape{6, 2};
std::vector<float> hidden_input_data {
-0.1928129941225052,
-0.4582270085811615,
-0.17884500324726105,
-0.27543601393699646,
0.704787015914917,
0.1874309927225113,
-0.28071099519729614,
-0.40605801343917847,
-0.4156219959259033,
0.6752780079841614,
0.4272859990596771,
-0.24114100635051727
};

Shape state_shape{1, 2};
std::vector<float> state_data { 0.0, 0.0 };


Tensor input_tensor =
makeInputTensor<DataType::FLOAT32>(input_shape, input_data, _memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);

Tensor hidden_hidden_tensor =
makeInputTensor<DataType::FLOAT32>(hidden_hidden_shape, hidden_hidden_data, _memory_manager.get());

Tensor hidden_input_tensor =
makeInputTensor<DataType::FLOAT32>(hidden_input_shape, hidden_input_data, _memory_manager.get());

Tensor state_tensor =
makeInputTensor<DataType::FLOAT32>(state_shape, state_data, _memory_manager.get());


GRUParams params{};

GRU kernel(&input_tensor, &hidden_hidden_tensor, nullptr, &hidden_input_tensor, nullptr, &state_tensor, &output_tensor, params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
kernel.execute();

EXPECT_THAT(extractTensorData<float>(output_tensor),
::testing::ElementsAreArray(ref_output_data));
//EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray({1, 1, 2}));


SUCCEED(); }

} // namespace
} // namespace kernels
Expand Down

0 comments on commit 034e72e

Please sign in to comment.