From d4fa1e1053aa6aa7e47f6e5e3fe4d4fd5c55c39d Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Tue, 6 Aug 2024 15:54:40 +0300 Subject: [PATCH 1/4] [luci/pass] Introduce FuseGRU Pass This PR introduces FuseGRUPass for fusing decomposed gru pattern into single CircleGRU. ONE-DCO-1.0-Signed-off-by: Artem Balyshev ONE-DCO-1.0-Signed-off-by: Chunseok Lee --- .../luci/pass/include/luci/CircleOptimizer.h | 1 + .../luci/pass/include/luci/Pass/FuseGRUPass.h | 39 + compiler/luci/pass/src/CircleOptimizer.cpp | 3 +- compiler/luci/pass/src/FuseGRUPass.cpp | 674 ++++++++++++++++++ compiler/luci/pass/src/FuseGRUPass.test.cpp | 418 +++++++++++ 5 files changed, 1134 insertions(+), 1 deletion(-) create mode 100644 compiler/luci/pass/include/luci/Pass/FuseGRUPass.h create mode 100644 compiler/luci/pass/src/FuseGRUPass.cpp create mode 100644 compiler/luci/pass/src/FuseGRUPass.test.cpp diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index ed7cbf611df..14323639f81 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -77,6 +77,7 @@ class CircleOptimizer final FuseActivationFunction, FusePRelu, FuseGelu, + FuseGRU, FuseRsqrt, FuseRmsNorm, FuseRoPE, diff --git a/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h new file mode 100644 index 00000000000..152dc427d95 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_GRU_PASS_H__ +#define __LUCI_FUSE_GRU_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleGRU + * + * For detailed subgraph pattern to be fused, please check its implementation. + */ +struct FuseGRUPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseGRUPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_GRU_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index ef6a2d86a4d..ea38e460393 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -52,6 +52,7 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" +#include "luci/Pass/FuseGRUPass.h" #include "luci/Pass/FuseRsqrtPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" #include "luci/Pass/FuseHorizontalFullyConnectedPass.h" @@ -398,7 +399,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const option_to_pass[Options::Algorithm::XpSepActFromTransposeConv] = &createPassInstance; option_to_pass[Options::Algorithm::ForwardReshapeToUnaryOp] = &createPassInstance; option_to_pass[Options::Algorithm::ForwardTransposeOp] = &createPassInstance; - // clang-format on + // clang-format on for (auto const &m : option_to_pass) { diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp new file mode 100644 index 00000000000..2f1f2d341ef --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -0,0 +1,674 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" +#include "helpers/NodeFiller.h" + +#include + +#include +#include + +#include + +#include + +// Helper to fuse GRU +namespace +{ + +class GRUPatternBase +{ +public: + GRUPatternBase(luci::CircleNode *candidate) { _pattern_last_node = candidate; } + + virtual ~GRUPatternBase() = default; + +public: + virtual bool matched() = 0; + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleConst *_weight_ih = nullptr; + luci::CircleConst *_bias_ih = nullptr; + luci::CircleConst *_weight_hh = nullptr; + luci::CircleConst *_bias_hh = nullptr; + + luci::CircleConst *_hidden_input = nullptr; + + luci::CircleConst *_less_const = nullptr; + + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + luci::CircleNode *_pattern_last_node = nullptr; +}; + +/** + * Below diagram shows GRU pattern to fuse. + * Note: this pattern for GRU with `return_sequences=False` + * - the below pattern will be replaced with one GRU + * Main Graph: + * [In] [CircleConst] [CircleConst] [CircleConst] [CircleConst] + * | | | | | + * V | | | | + * [CircleWhile]<----------------------------------------------------- + * | + * V + * [CircleWhileOut] + * | + * V + * [Out] + * + * Condition Graph: + * [In] [CircleConst] (scalar int32 value) + * | | + * V | + * [Less]------ + * | + * V + * [Out] + * + * Body Graph must contain: + * - 2 CircleFullyConnected nodes; + * - 3 CircleMul nodes; + * - 2 CircleLogistic nodes; + * - 2 CircleSplit nodes; + * - 6 CircleAdd nodes; + * - 1 CircleGather node; + * - 1 CircleReshape node; + * - 1 CircleSub node; + * - 1 CircleTanh node; + * - 6 CircleSplitOut nodes; + * - 5 CircleInput nodes; + * - 5 CircleOutput nodes; + * + * Body Graph: + * [In_1] [In_2]--->[Add_2 (with Const)]--->[Out_2] [In_3] + * | \ | | + * | \ [In_4]---[Gather] [Add_3 (with Const)] + * | [FullyConnected_1] | | | + * | | [Out_4] | [Out_3] + * | [Split_1] [FullyConnected_2] + * | / | \ | + * | | | \ [Split_2] + * | [Add_1] -------+----+---------------------------------/ | | + * | | | | | | + * | | | ------------------------------------[Add_4] | + * | | | | | + * | | | [Logistic_1] | + * | | | | | + * | | ----------------------------------------[Mul_2] | + * | | \ / + * | | [Add_5] + * | | | + * | [Logistic_2] [Tanh] + * \ / \ | + * [Mul_1] [Sub (with const)] | + * \ \ | + * \ ---------------------------[Mul_3] + * \ / + * \ / + * --------------------[Add_6]------------------------------ + * / \ + * / \ + * [Reshape] [Out_5] + * | + * [Out_1] + */ +class GRUPattern1 final : public GRUPatternBase +{ +public: + GRUPattern1(luci::CircleWhileOut *candidate) : GRUPatternBase(candidate) + { + assert(candidate); + _while_out_node = candidate; + } + +public: + bool matched() override; +}; + +bool GRUPattern1::matched() +{ + // 0 - check while node + _while_node = dynamic_cast(_while_out_node->input()); + if (_while_node == nullptr) + return false; + + // 1 - check condition graph: only one Less operation + // with scalar int const value + { + const auto cond_graph = _while_node->cond_graph(); + + const auto cond_nodes = loco::active_nodes(loco::output_nodes(cond_graph)); + if (cond_nodes.size() != 4) + return false; + luci::CircleLess *less_node = nullptr; + for (auto node : cond_nodes) + { + less_node = dynamic_cast(node); + if (less_node != nullptr) + break; + } + + // doesn't find Less node + if (less_node == nullptr) + return false; + + luci::CircleNode *less_input; + if (not luci::fill(&less_input, &_less_const).with_commutative_args_of(less_node)) + return false; + + if (_less_const->dtype() != loco::DataType::S32) + return false; + + if (_less_const->size() != 1) + return false; + + assert(_less_const->at(0) > 0); + } + + // 2 - Check while's input nodes + // Save hidden state input node + { + if (_while_node->input_count() != 5) + return false; + + // Save input node + _ifm = dynamic_cast(_while_node->input(4)); + if (_ifm == nullptr) + return false; + + _hidden_input = dynamic_cast(_while_node->input(3)); + if (_hidden_input == nullptr) + return false; + } + + // 3 - check body graph + { + const auto body_graph = _while_node->body_graph(); + + if (loco::input_nodes(body_graph).size() != 5) + return false; + + if (loco::output_nodes(body_graph).size() != 5) + return false; + + const auto body_nodes = loco::active_nodes(loco::output_nodes(body_graph)); + + // Save all nodes according its types + std::vector fc_nodes; + std::vector split_nodes; + std::vector logistic_nodes; + std::vector mul_nodes; + std::vector add_nodes; + std::vector sub_nodes; + std::vector reshape_nodes; + std::vector gather_nodes; + std::vector tanh_nodes; + std::vector split_out_nodes; + + for (auto node : body_nodes) + { + auto circle_node = dynamic_cast(node); + switch (circle_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + case luci::CircleOpcode::CIRCLEINPUT: + case luci::CircleOpcode::CIRCLEOUTPUT: + case luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE: + break; + case luci::CircleOpcode::FULLY_CONNECTED: + fc_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::SPLIT: + split_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::LOGISTIC: + logistic_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::MUL: + mul_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::ADD: + add_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::SUB: + sub_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::RESHAPE: + reshape_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::GATHER: + gather_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::TANH: + tanh_nodes.push_back(dynamic_cast(circle_node)); + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + split_out_nodes.push_back(dynamic_cast(circle_node)); + break; + default: + return false; + } + } + + // Check number of nodes + if (fc_nodes.size() != 2 or mul_nodes.size() != 3 or logistic_nodes.size() != 2 or + split_nodes.size() != 2 or add_nodes.size() != 6 or gather_nodes.size() != 1 or + reshape_nodes.size() != 1 or sub_nodes.size() != 1 or tanh_nodes.size() != 1 or + split_out_nodes.size() != 6) + return false; + + // Check structure + // TODO: add more checks + { + // 1 - Check Split ops + // Both has FC nodes as input + // Axis is const + for (auto node : split_nodes) + { + if (dynamic_cast(node->split_dim()) == nullptr or + dynamic_cast(node->input()) == nullptr) + return false; + } + + // 2 - Check Logistic ops + // Add is input node for both nodes + for (auto node : logistic_nodes) + { + if (dynamic_cast(node->x()) == nullptr) + return false; + } + + // 3 - Check Sub + // Const - is first input node + // Logistic - is second input node + for (auto node : sub_nodes) + { + if (dynamic_cast(node->y()) == nullptr or + dynamic_cast(node->x()) == nullptr) + return false; + } + + // 4 - Check Add + // Mul or Const or Input or Split ops can be input nodes + // Mul - 3 times as input + // Const - 2 times as input + // Input - 2 times as input + // Split - 5 times as input + { + int num_mul = 0; + int num_const = 0; + int num_input = 0; + int num_split = 0; + for (auto node : add_nodes) + { + auto x_node = dynamic_cast(node->x()); + auto y_node = dynamic_cast(node->y()); + switch (x_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + num_const++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::MUL: + num_mul++; + break; + default: + return false; + } + + switch (y_node->opcode()) + { + case luci::CircleOpcode::CIRCLECONST: + num_const++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::MUL: + num_mul++; + break; + default: + return false; + } + } + if (num_mul != 3 or num_split != 5 or num_const != 2 or num_input != 2) + return false; + } + } + + // 5 - Check Mul + // Logistic or Tanh or Sub or Input or Split ops can be input nodes + // Logistic - 2 times as input + // Tanh - 1 times as input + // Sub - 1 times as input + // Split - 1 times as input + // Input - 1 times as input + { + int num_logistic = 0; + int num_tanh = 0; + int num_sub = 0; + int num_split = 0; + int num_input = 0; + for (auto node : mul_nodes) + { + auto x_node = dynamic_cast(node->x()); + auto y_node = dynamic_cast(node->y()); + switch (x_node->opcode()) + { + case luci::CircleOpcode::LOGISTIC: + num_logistic++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::TANH: + num_tanh++; + break; + case luci::CircleOpcode::SUB: + num_sub++; + break; + default: + return false; + } + + switch (y_node->opcode()) + { + case luci::CircleOpcode::LOGISTIC: + num_logistic++; + break; + case luci::CircleOpcode::CIRCLEINPUT: + num_input++; + break; + case luci::CircleOpcode::CIRCLESPLITOUT: + num_split++; + break; + case luci::CircleOpcode::TANH: + num_tanh++; + break; + case luci::CircleOpcode::SUB: + num_sub++; + break; + default: + return false; + } + } + if (num_logistic != 2 or num_tanh != 1 or num_sub != 1 or num_split != 1 or num_input != 1) + return false; + } + + // 6 - Check Gather + // Gather has two CircleInput as input + { + for (auto node : gather_nodes) + { + if (dynamic_cast(node->indices()) == nullptr) + return false; + + if (dynamic_cast(node->params()) == nullptr) + return false; + } + } + + // 7 - Check Tanh + // Input is CircleAdd + { + for (auto node : tanh_nodes) + { + if (dynamic_cast(node->x()) == nullptr) + return false; + } + } + + // Find input and hidden FC weights and biases + for (auto node : body_nodes) + { + auto *fc_node = dynamic_cast(node); + if (fc_node == nullptr) + continue; + + const auto input_node = dynamic_cast(fc_node->input()); + if (input_node == nullptr) + return false; + + // For input hidden FullyConnected - input node is CircleInput node + if (dynamic_cast(input_node) != nullptr) + { + _weight_ih = dynamic_cast(fc_node->weights()); + _bias_ih = dynamic_cast(fc_node->bias()); + } + // For hidden hidden FullyConnected - input node is CircleGather node + else if (dynamic_cast(input_node) != nullptr) + { + _weight_hh = dynamic_cast(fc_node->weights()); + _bias_hh = dynamic_cast(fc_node->bias()); + } + else + { + return false; + } + } + + if (_weight_ih == nullptr or _weight_hh == nullptr) + return false; + } + + return true; +} + +class FuseGRU final +{ +public: + FuseGRU(const GRUPatternBase *p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CircleGRU *create_circle_gru(loco::Graph *graph); + +private: + const GRUPatternBase *_p; +}; + +template +void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned) +{ + assert(T == node->dtype()); + assert(T == cloned->dtype()); + + const auto size = node->size(); + cloned->size(size); + for (uint32_t i = 0; i < size; i++) + cloned->at(i) = node->at(i); +} + +luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph) +{ + auto cloned = graph->nodes()->create(); + + if (cloned != nullptr) + { + // dtype/shape + cloned->dtype(node->dtype()); + cloned->rank(node->rank()); + + // values + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + copy_values(node, cloned); + break; + + case loco::DataType::U8: + copy_values(node, cloned); + break; + + case loco::DataType::S8: + copy_values(node, cloned); + break; + + case loco::DataType::S16: + copy_values(node, cloned); + break; + + case loco::DataType::S32: + copy_values(node, cloned); + break; + + case loco::DataType::S64: + copy_values(node, cloned); + break; + + case loco::DataType::BOOL: + copy_values(node, cloned); + break; + + default: + assert(false); + } + } + + return cloned; +} + +luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) +{ + assert(graph); + + auto weight_ih_cloned = clone_circleconst(_p->_weight_ih, graph); + luci::copy_common_attributes(_p->_weight_ih, weight_ih_cloned); + + auto weight_hh_cloned = clone_circleconst(_p->_weight_hh, graph); + luci::copy_common_attributes(_p->_weight_hh, weight_hh_cloned); + + luci::CircleNode *bias_ih_cloned = nullptr; + if (_p->_bias_ih != nullptr) + { + bias_ih_cloned = clone_circleconst(_p->_bias_ih, graph); + luci::copy_common_attributes(_p->_bias_ih, bias_ih_cloned); + } + else + { + bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + } + + luci::CircleNode *bias_hh_cloned = nullptr; + if (_p->_bias_hh != nullptr) + { + bias_hh_cloned = clone_circleconst(_p->_bias_hh, graph); + luci::copy_common_attributes(_p->_bias_hh, bias_hh_cloned); + } + else + { + bias_hh_cloned = _p->_pattern_last_node->graph()->nodes()->create(); + } + + auto hidden_input_cloned = clone_circleconst(_p->_hidden_input, graph); + luci::copy_common_attributes(_p->_hidden_input, hidden_input_cloned); + + auto less_const_cloned = clone_circleconst(_p->_less_const, graph); + luci::copy_common_attributes(_p->_less_const, less_const_cloned); + + // Create and configure new CircleGRU operation. + auto circle_gru = _p->_while_node->graph()->nodes()->create(); + circle_gru->input(_p->_ifm); + circle_gru->hidden_hidden(weight_hh_cloned); + circle_gru->hidden_input(weight_ih_cloned); + circle_gru->hidden_hidden_bias(bias_hh_cloned); + circle_gru->hidden_input_bias(bias_ih_cloned); + circle_gru->state(hidden_input_cloned); + + // Note: Now support only returnSequences = false + circle_gru->returnSequences(false); + circle_gru->name("FusedCircleGRU"); + + return circle_gru; +} + +void FuseGRU::apply() +{ + auto graph = _p->_pattern_last_node->graph(); + + auto gru_out = create_circle_gru(graph); + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p->_while_node), luci::get_origin(_p->_while_out_node), + luci::get_origin(_p->_weight_hh), luci::get_origin(_p->_weight_ih)}; + + luci::add_origin(gru_out, luci::composite_origin(origin_vec)); + + replace(_p->_pattern_last_node).with(gru_out); +} + +} // namespace + +namespace +{ + +bool fuse_gru(luci::CircleWhileOut *while_out_node) +{ + assert(while_out_node); + + // check first pattern + GRUPattern1 pattern(while_out_node); + if (pattern.matched()) + { + FuseGRU fuse(&pattern); + fuse.apply(); + return true; + } + + return false; +} + +} // namespace + +namespace luci +{ + +bool FuseGRUPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto while_out_node = dynamic_cast(node); + if (not while_out_node) + continue; + + if (fuse_gru(while_out_node)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseGRUPass.test.cpp b/compiler/luci/pass/src/FuseGRUPass.test.cpp new file mode 100644 index 00000000000..93909ea673f --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.test.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class GRUGraphlet +{ +public: + GRUGraphlet() = default; + + void init(loco::Graph *g) + { + _while_node = g->nodes()->create(5, 5); + _while_out_node = g->nodes()->create(); + _hidden_node = g->nodes()->create(); + _hidden_node->dtype(loco::DataType::FLOAT32); + _time_node = g->nodes()->create(); + _time_node->dtype(loco::DataType::FLOAT32); + _state_node = g->nodes()->create(); + _state_node->dtype(loco::DataType::FLOAT32); + + _body_graph = loco::make_graph(); + _cond_graph = loco::make_graph(); + + _less_node = _cond_graph->nodes()->create(); + _less_const_node = _cond_graph->nodes()->create(); + _less_const_node->dtype(loco::DataType::S32); + _less_const_node->size(1); + _less_const_node->at(0) = 1; + + _add_node_1 = _body_graph->nodes()->create(); + _add_node_2 = _body_graph->nodes()->create(); + _add_node_3 = _body_graph->nodes()->create(); + _add_node_4 = _body_graph->nodes()->create(); + _add_node_5 = _body_graph->nodes()->create(); + _add_node_6 = _body_graph->nodes()->create(); + + _fc_node_1 = _body_graph->nodes()->create(); + _fc_node_2 = _body_graph->nodes()->create(); + _fc_weight_1 = _body_graph->nodes()->create(); + _fc_weight_1->dtype(loco::DataType::FLOAT32); + _fc_weight_2 = _body_graph->nodes()->create(); + _fc_weight_2->dtype(loco::DataType::FLOAT32); + _fc_bias_1 = _body_graph->nodes()->create(); + _fc_bias_1->dtype(loco::DataType::FLOAT32); + _fc_bias_2 = _body_graph->nodes()->create(); + _fc_bias_2->dtype(loco::DataType::FLOAT32); + + _split_const = _body_graph->nodes()->create(); + _split_const->dtype(loco::DataType::S32); + + _logistic_node_1 = _body_graph->nodes()->create(); + _logistic_node_2 = _body_graph->nodes()->create(); + + _gather_node = _body_graph->nodes()->create(); + + _mul_node_1 = _body_graph->nodes()->create(); + _mul_node_2 = _body_graph->nodes()->create(); + _mul_node_3 = _body_graph->nodes()->create(); + + _tanh_node = _body_graph->nodes()->create(); + _sub_node = _body_graph->nodes()->create(); + + _split_node_1 = _body_graph->nodes()->create(); + _split_node_2 = _body_graph->nodes()->create(); + _split_out_node_1 = _body_graph->nodes()->create(); + _split_out_node_2 = _body_graph->nodes()->create(); + _split_out_node_3 = _body_graph->nodes()->create(); + _split_out_node_4 = _body_graph->nodes()->create(); + _split_out_node_5 = _body_graph->nodes()->create(); + _split_out_node_6 = _body_graph->nodes()->create(); + + _reshape_node = _body_graph->nodes()->create(); + + auto graph_input_cond_graph = _cond_graph->inputs()->create(); + _cond_input_node = _cond_graph->nodes()->create(); + _cond_input_node->index(graph_input_cond_graph->index()); + + auto graph_output_cond_graph = _cond_graph->outputs()->create(); + _cond_output_node = _cond_graph->nodes()->create(); + _cond_output_node->index(graph_output_cond_graph->index()); + + auto graph_input_body_graph_1 = _body_graph->inputs()->create(); + _body_input_node_1 = _body_graph->nodes()->create(); + _body_input_node_1->index(graph_input_body_graph_1->index()); + + auto graph_input_body_graph_2 = _body_graph->inputs()->create(); + _body_input_node_2 = _body_graph->nodes()->create(); + _body_input_node_2->index(graph_input_body_graph_2->index()); + + auto graph_input_body_graph_3 = _body_graph->inputs()->create(); + _body_input_node_3 = _body_graph->nodes()->create(); + _body_input_node_3->index(graph_input_body_graph_3->index()); + + auto graph_input_body_graph_4 = _body_graph->inputs()->create(); + _body_input_node_4 = _body_graph->nodes()->create(); + _body_input_node_4->index(graph_input_body_graph_4->index()); + + auto graph_input_body_graph_5 = _body_graph->inputs()->create(); + _body_input_node_5 = _body_graph->nodes()->create(); + _body_input_node_5->index(graph_input_body_graph_5->index()); + + auto graph_output_body_graph_1 = _body_graph->outputs()->create(); + _body_output_node_1 = _body_graph->nodes()->create(); + _body_output_node_1->index(graph_output_body_graph_1->index()); + + auto graph_output_body_graph_2 = _body_graph->outputs()->create(); + _body_output_node_2 = _body_graph->nodes()->create(); + _body_output_node_2->index(graph_output_body_graph_2->index()); + + auto graph_output_body_graph_3 = _body_graph->outputs()->create(); + _body_output_node_3 = _body_graph->nodes()->create(); + _body_output_node_3->index(graph_output_body_graph_3->index()); + + auto graph_output_body_graph_4 = _body_graph->outputs()->create(); + _body_output_node_4 = _body_graph->nodes()->create(); + _body_output_node_4->index(graph_output_body_graph_4->index()); + + auto graph_output_body_graph_5 = _body_graph->outputs()->create(); + _body_output_node_5 = _body_graph->nodes()->create(); + _body_output_node_5->index(graph_output_body_graph_5->index()); + } + + void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } + +protected: + luci::CircleWhile *_while_node; + luci::CircleWhileOut *_while_out_node; + luci::CircleConst *_time_node; + luci::CircleConst *_state_node; + luci::CircleConst *_hidden_node; + + luci::CircleInput *_cond_input_node; + luci::CircleLess *_less_node; + luci::CircleConst *_less_const_node; + luci::CircleOutput *_cond_output_node; + + luci::CircleInput *_body_input_node_1; + luci::CircleInput *_body_input_node_2; + luci::CircleInput *_body_input_node_3; + luci::CircleInput *_body_input_node_4; + luci::CircleInput *_body_input_node_5; + + luci::CircleOutput *_body_output_node_1; + luci::CircleOutput *_body_output_node_2; + luci::CircleOutput *_body_output_node_3; + luci::CircleOutput *_body_output_node_4; + luci::CircleOutput *_body_output_node_5; + + luci::CircleAdd *_add_node_1; + luci::CircleAdd *_add_node_2; + luci::CircleAdd *_add_node_3; + luci::CircleAdd *_add_node_4; + luci::CircleAdd *_add_node_5; + luci::CircleAdd *_add_node_6; + + luci::CircleMul *_mul_node_1; + luci::CircleMul *_mul_node_2; + luci::CircleMul *_mul_node_3; + + luci::CircleSub *_sub_node; + luci::CircleTanh *_tanh_node; + luci::CircleReshape *_reshape_node; + luci::CircleGather *_gather_node; + luci::CircleLogistic *_logistic_node_1; + luci::CircleLogistic *_logistic_node_2; + luci::CircleSplit *_split_node_1; + luci::CircleSplit *_split_node_2; + + luci::CircleSplitOut *_split_out_node_1; + luci::CircleSplitOut *_split_out_node_2; + luci::CircleSplitOut *_split_out_node_3; + luci::CircleSplitOut *_split_out_node_4; + luci::CircleSplitOut *_split_out_node_5; + luci::CircleSplitOut *_split_out_node_6; + + luci::CircleFullyConnected *_fc_node_1; + luci::CircleFullyConnected *_fc_node_2; + + luci::CircleConst *_split_const; + luci::CircleConst *_fc_weight_1; + luci::CircleConst *_fc_bias_1; + luci::CircleConst *_fc_weight_2; + luci::CircleConst *_fc_bias_2; + + std::unique_ptr _cond_graph; + std::unique_ptr _body_graph; +}; + +class FuseGRUTestGraph1 : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestGraph1() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_1->y(_split_const); + _add_node_2->x(_body_input_node_2); + _add_node_2->y(_split_const); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _gather_node->indices(_body_input_node_1); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_1->split_dim(_split_const); + _split_node_2->input(_fc_node_2); + _split_node_2->split_dim(_split_const); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + _sub_node->x(_split_const); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +class FuseGRUTestNegGraph : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + invalid_less_const_type(); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_2->x(_body_input_node_2); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_2->input(_fc_node_2); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +} // namespace + +TEST(FuseGRUPassTest, fuse_pattern1) +{ + FuseGRUTestGraph1 g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseGRUPassTest, fuse_NEG) +{ + FuseGRUTestNegGraph g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +} From ec29daa017c3ce7fcb041775efb7c8031a187da2 Mon Sep 17 00:00:00 2001 From: Chunseok Lee Date: Mon, 28 Oct 2024 14:30:13 +0900 Subject: [PATCH 2/4] initialization on test class --- compiler/luci/pass/src/FuseGRUPass.test.cpp | 116 ++++++++++---------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.test.cpp b/compiler/luci/pass/src/FuseGRUPass.test.cpp index 93909ea673f..bb9df366606 100644 --- a/compiler/luci/pass/src/FuseGRUPass.test.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.test.cpp @@ -148,64 +148,64 @@ class GRUGraphlet void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } protected: - luci::CircleWhile *_while_node; - luci::CircleWhileOut *_while_out_node; - luci::CircleConst *_time_node; - luci::CircleConst *_state_node; - luci::CircleConst *_hidden_node; - - luci::CircleInput *_cond_input_node; - luci::CircleLess *_less_node; - luci::CircleConst *_less_const_node; - luci::CircleOutput *_cond_output_node; - - luci::CircleInput *_body_input_node_1; - luci::CircleInput *_body_input_node_2; - luci::CircleInput *_body_input_node_3; - luci::CircleInput *_body_input_node_4; - luci::CircleInput *_body_input_node_5; - - luci::CircleOutput *_body_output_node_1; - luci::CircleOutput *_body_output_node_2; - luci::CircleOutput *_body_output_node_3; - luci::CircleOutput *_body_output_node_4; - luci::CircleOutput *_body_output_node_5; - - luci::CircleAdd *_add_node_1; - luci::CircleAdd *_add_node_2; - luci::CircleAdd *_add_node_3; - luci::CircleAdd *_add_node_4; - luci::CircleAdd *_add_node_5; - luci::CircleAdd *_add_node_6; - - luci::CircleMul *_mul_node_1; - luci::CircleMul *_mul_node_2; - luci::CircleMul *_mul_node_3; - - luci::CircleSub *_sub_node; - luci::CircleTanh *_tanh_node; - luci::CircleReshape *_reshape_node; - luci::CircleGather *_gather_node; - luci::CircleLogistic *_logistic_node_1; - luci::CircleLogistic *_logistic_node_2; - luci::CircleSplit *_split_node_1; - luci::CircleSplit *_split_node_2; - - luci::CircleSplitOut *_split_out_node_1; - luci::CircleSplitOut *_split_out_node_2; - luci::CircleSplitOut *_split_out_node_3; - luci::CircleSplitOut *_split_out_node_4; - luci::CircleSplitOut *_split_out_node_5; - luci::CircleSplitOut *_split_out_node_6; - - luci::CircleFullyConnected *_fc_node_1; - luci::CircleFullyConnected *_fc_node_2; - - luci::CircleConst *_split_const; - luci::CircleConst *_fc_weight_1; - luci::CircleConst *_fc_bias_1; - luci::CircleConst *_fc_weight_2; - luci::CircleConst *_fc_bias_2; + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + luci::CircleConst *_time_node = nullptr; + luci::CircleConst *_state_node = nullptr; + luci::CircleConst *_hidden_node = nullptr; + + luci::CircleInput *_cond_input_node = nullptr; + luci::CircleLess *_less_node = nullptr; + luci::CircleConst *_less_const_node = nullptr; + luci::CircleOutput *_cond_output_node = nullptr; + + luci::CircleInput *_body_input_node_1 = nullptr; + luci::CircleInput *_body_input_node_2 = nullptr; + luci::CircleInput *_body_input_node_3 = nullptr; + luci::CircleInput *_body_input_node_4 = nullptr; + luci::CircleInput *_body_input_node_5 = nullptr; + + luci::CircleOutput *_body_output_node_1 = nullptr; + luci::CircleOutput *_body_output_node_2 = nullptr; + luci::CircleOutput *_body_output_node_3 = nullptr; + luci::CircleOutput *_body_output_node_4 = nullptr; + luci::CircleOutput *_body_output_node_5 = nullptr; + + luci::CircleAdd *_add_node_1 = nullptr; + luci::CircleAdd *_add_node_2 = nullptr; + luci::CircleAdd *_add_node_3 = nullptr; + luci::CircleAdd *_add_node_4 = nullptr; + luci::CircleAdd *_add_node_5 = nullptr; + luci::CircleAdd *_add_node_6 = nullptr; + + luci::CircleMul *_mul_node_1 = nullptr; + luci::CircleMul *_mul_node_2 = nullptr; + luci::CircleMul *_mul_node_3 = nullptr; + + luci::CircleSub *_sub_node = nullptr; + luci::CircleTanh *_tanh_node = nullptr; + luci::CircleReshape *_reshape_node = nullptr; + luci::CircleGather *_gather_node = nullptr; + luci::CircleLogistic *_logistic_node_1 = nullptr; + luci::CircleLogistic *_logistic_node_2 = nullptr; + luci::CircleSplit *_split_node_1 = nullptr; + luci::CircleSplit *_split_node_2 = nullptr; + + luci::CircleSplitOut *_split_out_node_1 = nullptr; + luci::CircleSplitOut *_split_out_node_2 = nullptr; + luci::CircleSplitOut *_split_out_node_3 = nullptr; + luci::CircleSplitOut *_split_out_node_4 = nullptr; + luci::CircleSplitOut *_split_out_node_5 = nullptr; + luci::CircleSplitOut *_split_out_node_6 = nullptr; + + luci::CircleFullyConnected *_fc_node_1 = nullptr; + luci::CircleFullyConnected *_fc_node_2 = nullptr; + + luci::CircleConst *_split_const = nullptr; + luci::CircleConst *_fc_weight_1 = nullptr; + luci::CircleConst *_fc_bias_1 = nullptr; + luci::CircleConst *_fc_weight_2 = nullptr; + luci::CircleConst *_fc_bias_2 = nullptr; std::unique_ptr _cond_graph; std::unique_ptr _body_graph; From de1f543d4ec7b401641392b0b1c342442f2b7578 Mon Sep 17 00:00:00 2001 From: chunseoklee Date: Mon, 28 Oct 2024 14:34:03 +0900 Subject: [PATCH 3/4] Update compiler/luci/pass/src/FuseGRUPass.cpp Co-authored-by: SaeHie Park --- compiler/luci/pass/src/FuseGRUPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index 2f1f2d341ef..ab98afdee3c 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -104,7 +104,7 @@ class GRUPatternBase * | [Split_1] [FullyConnected_2] * | / | \ | * | | | \ [Split_2] - * | [Add_1] -------+----+---------------------------------/ | | + * | [Add_1] ----------------------------------------------/ | | * | | | | | | * | | | ------------------------------------[Add_4] | * | | | | | From 604fb70d4a18b86256fb750a5147bab1cf244d5b Mon Sep 17 00:00:00 2001 From: Chunseok Lee Date: Mon, 28 Oct 2024 19:16:06 +0900 Subject: [PATCH 4/4] throw instead of assert --- compiler/luci/pass/src/FuseGRUPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp index ab98afdee3c..c6f0d58c8fd 100644 --- a/compiler/luci/pass/src/FuseGRUPass.cpp +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -551,7 +551,7 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph break; default: - assert(false); + throw std::runtime_error("Unsupported data type"); } }