You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by li...@apache.org on 2021/03/05 21:47:28 UTC
[tvm] branch main updated: [Relay][Pass] Avoid stack overflow when using PostOrderRewrite (#7588)
This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1ae4697 [Relay][Pass] Avoid stack overflow when using PostOrderRewrite (#7588)
1ae4697 is described below
commit 1ae469789342e12b685f216f4e64e199accb0f47
Author: Huang, Guangtai <hg...@foxmail.com>
AuthorDate: Sat Mar 6 05:47:08 2021 +0800
[Relay][Pass] Avoid stack overflow when using PostOrderRewrite (#7588)
* init
* fix
* fix
---
src/relay/ir/expr_functor.cc | 30 ++++++++++++++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index d70c6fe..5984a20 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -103,11 +103,41 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) {
class PostOrderRewriter : public MixedModeMutator {
public:
explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
+
+ // Post-order step: first mutate the node's children through the ExprFunctor
+ // dispatch, then pass both the original node and its mutated form to the
+ // user-supplied rewriter_.
Expr DispatchVisitExpr(const Expr& expr) final {
auto post = ExprFunctor::VisitExpr(expr);
return rewriter_->Rewrite(expr, post);
}
+ // Re-expose the base-class VisitExpr_ overloads; declaring the LetNode
+ // overload below would otherwise hide all of them (C++ name hiding).
+ using MixedModeMutator::VisitExpr_;
+
+ // Handles chains of nested Let expressions. Per the commit title, this is
+ // meant to avoid a stack overflow when PostOrderRewrite meets long Let
+ // chains, by not recursing once per Let body. The walk itself is delegated
+ // to ExpandANormalForm; presumably it calls pre_visit outer-to-inner and
+ // post_visit inner-to-outer along the chain of Let bodies.
+ // NOTE(review): confirm that ordering against ExpandANormalForm's definition.
+ Expr VisitExpr_(const LetNode* node) final {
+ // Mutate var and value of each Let on the way down. The results are
+ // discarded; presumably this is done only for its memo_ side effect, so
+ // that the identical Mutate calls in post_visit are cache hits. Confirm.
+ auto pre_visit = [this](const LetNode* op) {
+ Expr var = this->Mutate(op->var);
+ Expr value = this->Mutate(op->value);
+ };
+ // Rebuild each Let on the way back up and record the result in memo_.
+ // `node` is captured so the outermost Let can be treated specially.
+ auto post_visit = [this, node](const LetNode* op) {
+ Var var = Downcast<Var>(this->Mutate(op->var));
+ Expr value = this->Mutate(op->value);
+ Expr body = this->Mutate(op->body);
+ Expr expr = GetRef<Expr>(op);
+ Expr post;
+ // Reuse the original node when no child changed; otherwise build a new Let.
+ if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
+ post = expr;
+ } else {
+ post = Let(var, value, body);
+ }
+ // avoid rewriting the first LetNode twice: the outermost Let (`node`) is,
+ // presumably, reached from DispatchVisitExpr, which already applies
+ // rewriter_->Rewrite to this method's return value. So only the inner
+ // Lets are rewritten here; the outermost one is stored un-rewritten.
+ if (op == node) {
+ this->memo_[expr] = post;
+ } else {
+ this->memo_[expr] = this->rewriter_->Rewrite(expr, post);
+ }
+ };
+ ExpandANormalForm(node, pre_visit, post_visit);
+ // post_visit has stored the result for the outermost Let in memo_.
+ return memo_[GetRef<Expr>(node)];
+ }
+
protected:
ExprRewriter* rewriter_;
};