Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Introduce FuseGRU Pass #14252

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

chunseoklee
Copy link
Contributor

This PR introduces FuseGRUPass for fusing decomposed gru pattern into single CircleGRU.

draft : #14237
issue : #12263

ONE-DCO-1.0-Signed-off-by: Artem Balyshev a.balyshev@samsung.com
ONE-DCO-1.0-Signed-off-by: Chunseok Lee chunseok.lee@samsung.com

This PR introduces FuseGRUPass for fusing decomposed gru pattern into single CircleGRU.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <a.balyshev@samsung.com>
ONE-DCO-1.0-Signed-off-by: Chunseok Lee <chunseok.lee@samsung.com>
if (_while_node == nullptr)
return false;

// 1 - check condition graph: only one Less operation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only one Less operation and below condition doesn't match.

  • 1/ fix comment like Less operation should exist
  • 2/ fix implementation to check only one Less exist

break;
}

// doesn't find Less node
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, this comment can be removed

Comment on lines +270 to +274
if (fc_nodes.size() != 2 or mul_nodes.size() != 3 or logistic_nodes.size() != 2 or
split_nodes.size() != 2 or add_nodes.size() != 6 or gather_nodes.size() != 1 or
reshape_nodes.size() != 1 or sub_nodes.size() != 1 or tanh_nodes.size() != 1 or
split_out_nodes.size() != 6)
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I commented in the draft code, I don't agree on only check number of Ops to check.

split_out_nodes.size() != 6)
return false;

// Check structure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the algorithm of below check codes.
Please explain what is happening.

luci::CircleGRU *create_circle_gru(loco::Graph *graph);

private:
const GRUPatternBase *_p;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const GRUPatternBase *_p;
// initialized at ctor
const GRUPatternBase *_p;

break;

default:
assert(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's throw instead of assert

}
else
{
bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create<luci::CircleOutputExclude>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create<luci::CircleOutputExclude>();
bias_ih_cloned = graph->nodes()->create<luci::CircleOutputExclude>();

?


luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph)
{
assert(graph);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this graph and _p->_pattern_last_node->graph() are different, please add a not about this.
As I understand, we're looking with multiple sub graph objects so graph ptr can be different.

void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); }

protected:
luci::CircleWhile *_while_node;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
luci::CircleWhile *_while_node;
luci::CircleWhile *_while_node = nullptr;

and others in below too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated as suggested.

* |
* [Out_1]
*/
class GRUPattern1 final : public GRUPatternBase
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume 1 suffix in GRUPattern1 is to prepare more patterns.
Please leave a note about thus.
Or please use just GRUPattern here and later we can rename this when we add more patterns.


g.init();

EXPECT_FALSE(pass.run(g.g()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adding a note about what negative would help understanding FuseGRUTestNegGraph.
It's a bunch of nodes connected and I can't catch what it's doing.

@@ -551,7 +551,7 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph
break;

default:
assert(false);
throw std::runtime_error("Unsupported data type");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
throw std::runtime_error("Unsupported data type");
throw std::runtime_error("FuseGRU: Unsupported data type");

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants