You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by li...@apache.org on 2021/03/05 21:47:28 UTC

[tvm] branch main updated: [Relay][Pass] Avoid stack overflow when using PostOrderRewrite (#7588)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ae4697  [Relay][Pass] Avoid stack overflow when using PostOrderRewrite (#7588)
1ae4697 is described below

commit 1ae469789342e12b685f216f4e64e199accb0f47
Author: Huang, Guangtai <hg...@foxmail.com>
AuthorDate: Sat Mar 6 05:47:08 2021 +0800

    [Relay][Pass] Avoid stack overflow when using PostOrderRewrite (#7588)
    
    * init
    
    * fix
    
    * fix
---
 src/relay/ir/expr_functor.cc | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index d70c6fe..5984a20 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -103,11 +103,41 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) {
 class PostOrderRewriter : public MixedModeMutator {
  public:
   explicit PostOrderRewriter(ExprRewriter* rewriter) : rewriter_(rewriter) {}
+
   Expr DispatchVisitExpr(const Expr& expr) final {
     auto post = ExprFunctor::VisitExpr(expr);
     return rewriter_->Rewrite(expr, post);
   }
 
+  using MixedModeMutator::VisitExpr_;  // keep the base-class overloads visible
+  // Let nodes go through ExpandANormalForm; per the commit title this avoids stack overflow.
+  Expr VisitExpr_(const LetNode* node) final {
+    auto pre_visit = [this](const LetNode* op) {
+      Expr var = this->Mutate(op->var);  // result unused; presumably cached by Mutate (confirm)
+      Expr value = this->Mutate(op->value);  // result unused, as above
+    };
+    auto post_visit = [this, node](const LetNode* op) {
+      Var var = Downcast<Var>(this->Mutate(op->var));
+      Expr value = this->Mutate(op->value);
+      Expr body = this->Mutate(op->body);
+      Expr expr = GetRef<Expr>(op);
+      Expr post;
+      if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
+        post = expr;  // no child changed: reuse the original Let
+      } else {
+        post = Let(var, value, body);  // some child changed: build a new Let
+      }
+      // Skip Rewrite for `node`: presumably DispatchVisitExpr rewrites it afterwards (confirm).
+      if (op == node) {
+        this->memo_[expr] = post;
+      } else {
+        this->memo_[expr] = this->rewriter_->Rewrite(expr, post);  // every other Let: rewrite now
+      }
+    };
+    ExpandANormalForm(node, pre_visit, post_visit);  // defined elsewhere; presumably runs both
+    return memo_[GetRef<Expr>(node)];  // expected: the entry post_visit stored for `node`
+  }
+
   protected:
   ExprRewriter* rewriter_;
 };