Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/20 00:26:43 UTC

[GitHub] [tvm] mbs-octoml opened a new pull request #9542: Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

mbs-octoml opened a new pull request #9542:
URL: https://github.com/apache/tvm/pull/9542


   As part of #9483 we need to prepare some critical Relay passes to run after
   lowering and conversion to destination-passing style (DPS). For DCE we must
   make sure we never remove side-effecting let-bound expressions, such as
   allocations or calls to external functions whose effects are unknown.
   
   This introduces a new purity pre-pass. It makes a half-hearted attempt at
   accounting for functions by tracking both 'eval' and 'call' purity, but must
   fall back to assuming call impurity in more difficult cases (e.g. calling a
   function passed as a parameter or projected from a tuple). Still, it seems
   plenty good enough in practice.
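A minimal C++ sketch of how the two purity bits might combine at a call site. It assumes a toy `Purity` struct and a caller-supplied `result_is_first_order` flag in place of real type inspection; neither is TVM's actual API. It follows the rule that evaluating a call is pure only if the callee, its arguments, and the body the callee runs are all pure, and that a function-valued result inherits the callee's call purity:

```cpp
#include <cassert>
#include <vector>

// Toy model of the two purity bits (hypothetical; not TVM's actual classes).
struct Purity {
  bool pure_eval;  // evaluating the expression itself has no side effects
  bool pure_call;  // calling the function-valued result has no side effects
                   // (always true for first-order values)
};

// Combines callee and argument purity for one call expression.
// `result_is_first_order` stands in for inspecting the call's result type.
Purity CallPurity(const Purity& callee, const std::vector<Purity>& args,
                  bool result_is_first_order) {
  bool all_args_pure_eval = true;
  for (const Purity& arg : args) {
    all_args_pure_eval = all_args_pure_eval && arg.pure_eval;
  }
  // Evaluating the call is pure only if the arguments, the callee expression,
  // and the body the callee runs are all pure.
  bool pure_eval = all_args_pure_eval && callee.pure_eval && callee.pure_call;
  // A function-valued result inherits the callee's call purity.
  bool pure_call = result_is_first_order || callee.pure_call;
  return {pure_eval, pure_call};
}
```

Here a callee with `{true, false}` models something like an external function: naming it is harmless, but calling it must be assumed to have effects.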
   
   Purity must also be accounted for when determining the usage count of let-bound
   variables, so I reworked that. To offset the cost of the new purity analysis
   pass, I collapsed the let-bound value accumulation pass into the usage-counting
   pass.
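To show why purity has to feed into the usage decision, here is a hypothetical per-binding decision function for `let var = value; body`. The `LetAction` enum, the function name, and the `inline_once` flag are illustrative assumptions, not the pass's real interface. An unused binding may be dropped, and a singly used one inlined, only when its value is pure:

```cpp
#include <cassert>

// Hypothetical outcome for one let-binding.
enum class LetAction { kKeep, kElide, kInline };

// Decides what to do with `let var = value; body`, given how many times
// `var` is used in `body` and whether evaluating `value` is pure.
LetAction DecideLetAction(int use_count, bool value_is_pure, bool inline_once) {
  assert(use_count >= 0);
  if (!value_is_pure) {
    // Side effects (e.g. an allocation or an opaque external call) must run
    // exactly once and in program order, so the binding always stays.
    return LetAction::kKeep;
  }
  if (use_count == 0) {
    return LetAction::kElide;  // pure and unused: safe to drop
  }
  if (use_count == 1 && inline_once) {
    return LetAction::kInline;  // pure and used once: safe to substitute
  }
  return LetAction::kKeep;
}
```

Note that an unused but impure binding is kept, which is exactly the case (e.g. an allocation) that must survive after ManifestAlloc.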
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbs-octoml commented on a change in pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754743614



##########
File path: src/relay/op/memory/memory.h
##########
@@ -35,6 +35,7 @@ namespace tvm {
 namespace relay {
 
 Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint);
+const Op& MemoryAllocTensorOp();

Review comment:
      Oh, it snuck in from the main change, sorry. I've added a comment; I'm just trying to cut down on the duplicated special operator name strings.
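A self-contained C++ sketch of the cached-accessor pattern being discussed. `Op` and `LookupOp` are hypothetical stand-ins for TVM's operator registry, and the registered name `memory.alloc_tensor` is assumed; the point is that the name string lives in exactly one place and every caller receives the same object:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a registered operator (not TVM's Op class).
struct Op {
  std::string name;
};

// Returns the registered operator for `name`, creating it on first use.
// References to std::unordered_map elements stay valid across later inserts.
const Op& LookupOp(const std::string& name) {
  static std::unordered_map<std::string, Op> registry;
  auto it = registry.find(name);
  if (it == registry.end()) {
    it = registry.emplace(name, Op{name}).first;
  }
  return it->second;
}

// Single home for the operator's name string: callers compare against the
// returned reference instead of repeating the string literal.
const Op& MemoryAllocTensorOp() {
  static const Op& op = LookupOp("memory.alloc_tensor");
  return op;
}
```

Since C++11, initialization of a function-local static is thread-safe, so the lookup runs once even under concurrent first calls.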







[GitHub] [tvm] jroesch commented on a change in pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

jroesch commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754699874



##########
File path: src/relay/op/memory/memory.h
##########
@@ -35,6 +35,7 @@ namespace tvm {
 namespace relay {
 
 Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint);
+const Op& MemoryAllocTensorOp();

Review comment:
      This looks kind of odd; can you at least provide a comment?







[GitHub] [tvm] mikepapadim commented on pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

mikepapadim commented on pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#issuecomment-975996959


   LGTM
   





[GitHub] [tvm] mbs-octoml commented on a change in pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

mbs-octoml commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754747093



##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
  */
 
 /*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
  *
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used value.
+ * TODO(mbs): Track dead writes into references.
  */
+
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
-#include "let_list.h"
+#include "../op/call/call.h"
 
 namespace tvm {
 namespace relay {
+namespace {
 
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyize. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+  /*!
+   * \brief True if evaling the sub-expression itself is pure.
+   */
+  bool pure_eval;
+  /*!
+   * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling
+   * a call to the the sub-expression is pure. See [RULE A] below.
+   */
+  bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f:
+ *  - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is
+ *    also call-pure.
+ *  - [RULE B] Higher-order function arguments are assumed call impure.
+ *  - [RULE C] We assume functions extracted from tuples are call impure.
+ *  - [RULE D] We assume functions extracted from references are call impure.
+ *  - [RULE E] We assume functions extracted from ADTs are call impure.
+ *  - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+  explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {}
+
+  /*! \brief Visit all the functions in the module. */
+  void VisitModule() {
+    VLOG_CONTEXT << "PurityVisitor";
+    // It is safe to visit the global functions in any order. Recursive global functions are
+    // allowed.
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
+        if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+            function_node->GetAttr<String>(attr::kExternalSymbol)) {
+          // Ignore primitive and external functions.
+          continue;
+        }
+        // Everything of interest will be recorded in the purity maps so we ignore the result.
+        (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+      }
+    }
+  }
+
+  /*!
+   * \brief Returns a map from every let-bound variable to whether its let-bound value is
+   * definitely pure.
+   */
+  std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+    std::unordered_map<const VarNode*, bool> result;
+    for (const auto& kv : var_to_purity_) {
+      result.emplace(kv.first, kv.second.pure_eval);
+    }
+    return result;
+  }
 
-class CalcDep;
-class FindDef : private ExprVisitor {
  private:
-  VarMap<Expr> expr_map_;
+  Purity VisitExpr(const Expr& expr) final {
+    auto it = memo_.find(expr.get());
+    if (it != this->memo_.end()) {
+      return it->second;
+    } else {
+      Purity result = ExprFunctor::VisitExpr(expr);
+      memo_[expr.get()] = result;
+      return result;
+    }
+  }
 
-  void VisitExpr_(const LetNode* l) final {
-    auto pre_visit = [this](const LetNode* op) {
-      ICHECK_EQ(expr_map_.count(op->var), 0);
-      expr_map_[op->var] = op->value;
-      this->VisitExpr(op->value);
-    };
-    auto post_visit = [this](const LetNode* op) {
-      this->VisitExpr(op->body);
-      this->visit_counter_[op] += 1;
-    };
-    ExpandANormalForm(l, pre_visit, post_visit);
+  Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; }
+
+  Purity VisitExpr_(const ConstructorNode*) final {
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const OpNode* op_node) final {
+    // Primitive operators are pure unless marked as 'stateful'.
+    static OpAttrMap<bool> attr_map = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+    bool is_statefull = attr_map.count(GetRef<Op>(op_node)) && attr_map[GetRef<Op>(op_node)];

Review comment:
       done
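The operator rule in the hunk above (primitive operators are pure unless marked as stateful, and merely referencing an operator is always pure) can be modeled with a plain lookup table. This is a sketch only: `StatefulAttrMap` is a hypothetical stand-in for the `TOpIsStateful` attribute map, and the operator names used with it are made up:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

struct Purity {
  bool pure_eval;
  bool pure_call;
};

// Hypothetical per-operator attribute table standing in for TOpIsStateful.
// Operators missing from the table are treated as not stateful.
using StatefulAttrMap = std::unordered_map<std::string, bool>;

Purity OpPurity(const StatefulAttrMap& attr_map, const std::string& op_name) {
  auto it = attr_map.find(op_name);
  bool is_stateful = it != attr_map.end() && it->second;
  // Referencing an operator is always pure; only calling a stateful one is not.
  return {/*pure_eval=*/true, /*pure_call=*/!is_stateful};
}
```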




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbs-octoml commented on a change in pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

mbs-octoml commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754747363



##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
  */
 
 /*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
  *
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used value.
+ * TODO(mbs): Track dead writes into references.
  */
+
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
-#include "let_list.h"
+#include "../op/call/call.h"
 
 namespace tvm {
 namespace relay {
+namespace {
 
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyize. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+  /*!
+   * \brief True if evaling the sub-expression itself is pure.
+   */
+  bool pure_eval;
+  /*!
+   * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling
+   * a call to the the sub-expression is pure. See [RULE A] below.
+   */
+  bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f:
+ *  - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is
+ *    also call-pure.
+ *  - [RULE B] Higher-order function arguments are assumed call impure.
+ *  - [RULE C] We assume functions extracted from tuples are call impure.
+ *  - [RULE D] We assume functions extracted from references are call impure.
+ *  - [RULE E] We assume functions extracted from ADTs are call impure.
+ *  - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+  explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {}
+
+  /*! \brief Visit all the functions in the module. */
+  void VisitModule() {
+    VLOG_CONTEXT << "PurityVisitor";
+    // It is safe to visit the global functions in any order. Recursive global functions are
+    // allowed.
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
+        if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+            function_node->GetAttr<String>(attr::kExternalSymbol)) {
+          // Ignore primitive and external functions.
+          continue;
+        }
+        // Everything of interest will be recorded in the purity maps so we ignore the result.
+        (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+      }
+    }
+  }
+
+  /*!
+   * \brief Returns a map from every let-bound variable to whether its let-bound value is
+   * definitely pure.
+   */
+  std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+    std::unordered_map<const VarNode*, bool> result;
+    for (const auto& kv : var_to_purity_) {
+      result.emplace(kv.first, kv.second.pure_eval);
+    }
+    return result;
+  }
 
-class CalcDep;
-class FindDef : private ExprVisitor {
  private:
-  VarMap<Expr> expr_map_;
+  Purity VisitExpr(const Expr& expr) final {
+    auto it = memo_.find(expr.get());
+    if (it != this->memo_.end()) {
+      return it->second;
+    } else {
+      Purity result = ExprFunctor::VisitExpr(expr);
+      memo_[expr.get()] = result;
+      return result;
+    }
+  }
 
-  void VisitExpr_(const LetNode* l) final {
-    auto pre_visit = [this](const LetNode* op) {
-      ICHECK_EQ(expr_map_.count(op->var), 0);
-      expr_map_[op->var] = op->value;
-      this->VisitExpr(op->value);
-    };
-    auto post_visit = [this](const LetNode* op) {
-      this->VisitExpr(op->body);
-      this->visit_counter_[op] += 1;
-    };
-    ExpandANormalForm(l, pre_visit, post_visit);
+  Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; }
+
+  Purity VisitExpr_(const ConstructorNode*) final {
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const OpNode* op_node) final {
+    // Primitive operators are pure unless marked as 'stateful'.
+    static OpAttrMap<bool> attr_map = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+    bool is_statefull = attr_map.count(GetRef<Op>(op_node)) && attr_map[GetRef<Op>(op_node)];
+    return {/*pure_eval=*/true, /*pure_call=*/!is_statefull};
+  }
+
+  Purity VisitExpr_(const GlobalVarNode* global_var_node) final {
+    auto global_var = GetRef<GlobalVar>(global_var_node);
+    auto func = mod_->Lookup(global_var);
+    if (const auto* function_node = func.as<FunctionNode>()) {
+      if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
+        return VisitGlobalFunction(global_var, GetRef<Function>(function_node));
+      }
+    }
+    // Assume externals and PrimFuncs are call-impure [RULE F].
+    // (If they are pure then we should have dealt with them before lowering.)
+    return {/*pure_eval==*/true, /*pure_call=*/false};
+  }
+
+  Purity VisitExpr_(const VarNode* var_node) final {
+    // The var is bound to a value, but if that value is a function we need to propagate the
+    // function body's purity.
+    ICHECK(var_to_purity_.count(var_node)) << PrettyPrint(GetRef<Var>(var_node));
+    return {/*pure_eval=*/true, /*pure_call=*/var_to_purity_[var_node].pure_call};
+  }
+
+  Purity VisitExpr_(const FunctionNode* function_node) final {
+    for (const auto& param : function_node->params) {
+      // Any higher-order parameters are assumed to be call-impure [RULE B]
+      var_to_purity_[param.get()] = {/*pure_eval=*/true, /*pure_call=*/IsFirstOrder(param)};
+    }
+    Purity body_purity = VisitExpr(function_node->body);
+    // The function itself is a value and thus pure. If the function returns
+    // a function we'll fold its purity in here [RULE A]
+    return {/*pure_eval=*/true, /*pure_call=*/body_purity.pure_eval && body_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const LetNode* let_node) final {
+    Expr expr = GetRef<Expr>(let_node);
+    bool all_values_pure_eval = true;
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      // In case the value is a recursive function assume the let-bound variable is call-pure.
+      var_to_purity_[inner_let_node->var.get()] = {/*pure_eval=*/true, /*pure_call=*/true};
+      Purity value_purity = VisitExpr(inner_let_node->value);
+      // Now revise the variable to it's true purity.
+      var_to_purity_[inner_let_node->var.get()] = value_purity;
+      VLOG(2) << (value_purity.pure_eval ? "pure" : "impure") << " expression:" << std::endl
+              << PrettyPrint(inner_let_node->value) << std::endl
+              << "let-bound to variable:" << std::endl
+              << PrettyPrint(inner_let_node->var);
+      all_values_pure_eval = all_values_pure_eval && value_purity.pure_eval;
+      expr = inner_let_node->body;
+    }
+    Purity body_purity = VisitExpr(expr);
+    return {/*pure_eval=*/all_values_pure_eval && body_purity.pure_eval,
+            /*pure_call=*/body_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const CallNode* call_node) final {
+    if (current_call_depth_ >= kMaxCallDepth) {
+      // Assume impure.
+      VLOG(2) << "assuming call is impure since too deeply nested";
+      return {/*pure_eval=*/false, /*pure_call*/ IsFirstOrder(GetRef<Call>(call_node))};
+    }
+
+    ++current_call_depth_;
+
+    // We can work with the call in both pre- and post-lowered form.
+    Expr callee;
+    Array<Expr> args;
+    if (call_node->op == CallLoweredOp()) {
+      CallLoweredProps props = GetCallLoweredProps(call_node);
+      callee = props.lowered_func;
+      args = props.arguments;
+    } else {
+      callee = call_node->op;
+      args = call_node->args;
+    }
+
+    // Find purity for the callee and the args.
+    Purity callee_purity = VisitExpr(callee);
+    bool all_args_pure_eval = true;
+    for (const auto& arg : args) {
+      Purity arg_purity = VisitExpr(arg);
+      all_args_pure_eval = all_args_pure_eval && arg_purity.pure_eval;
+    }
+
+    VLOG(2) << (callee_purity.pure_call ? "pure" : "impure") << " call to:" << std::endl
+            << PrettyPrint(callee);
+
+    ICHECK_GT(current_call_depth_, 0);
+    --current_call_depth_;
+
+    // If the callee's result is itself a function then by [RULE A] its purity
+    // is given by callee_purity.pure_call.
+    return {/*pure_eval=*/all_args_pure_eval && callee_purity.pure_eval && callee_purity.pure_call,
+            /*pure_call=*/IsFirstOrder(GetRef<Call>(call_node)) || callee_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const IfNode* if_node) final {
+    Purity cond_purity = VisitExpr(if_node->cond);
+    ICHECK(cond_purity.pure_call);  // conditional is first-order
+    Purity true_purity = VisitExpr(if_node->true_branch);
+    Purity false_purity = VisitExpr(if_node->false_branch);
+    return {/*pure_eval=*/cond_purity.pure_eval && true_purity.pure_eval && false_purity.pure_eval,
+            /*pure_call=*/true_purity.pure_call && false_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const TupleNode* tuple_node) final {
+    bool all_fields_pure = true;
+    for (const auto& field : tuple_node->fields) {
+      // The call purity of each tuple field is lost [RULE C].
+      Purity field_purity = VisitExpr(field);
+      if (!field_purity.pure_eval) {
+        all_fields_pure = false;
+      }
+    }
+    return {/*pure_eval=*/all_fields_pure, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    Purity tuple_purity = VisitExpr(tuple_get_item_node->tuple);
+    ICHECK(tuple_purity.pure_call);  // tuple is first-order
+    // We don't track call purity through tuple fields, so if the result is a function type we
+    // must assume it is call impure [RULE C].
+    return {/*pure_eval=*/tuple_purity.pure_eval,
+            /*pure_call=*/IsFirstOrder(GetRef<TupleGetItem>(tuple_get_item_node))};
+  }
+
+  Purity VisitExpr_(const RefCreateNode*) final {
+    // The creation of the  ref itself is unobservable other than via the reads/writes into it.
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const RefWriteNode* ref_write_node) final {
+    Purity ref_purity = VisitExpr(ref_write_node->ref);
+    ICHECK(ref_purity.pure_call);  // reference is first-order
+    // The call purity of the written value is lost [RULE D].

Review comment:
       added a comment to explain -- gotta accumulate any let-bindings inside the value.
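A toy C++ model of why the written value is still visited even though its call purity is discarded [RULE D]: let-bindings nested inside the value must still land in the purity map that the dead-code pass consults. The miniature IR below (`Expr`, `Make`, `ToyPurityVisitor`, and the node kinds) is hypothetical, and the reference being written to is elided for brevity. The `kLet` case also mirrors the optimistic pre-binding used for possibly recursive values:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Purity {
  bool pure_eval;
  bool pure_call;
};

// Hypothetical miniature IR with just enough node kinds to show the idea.
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;
struct Expr {
  enum Kind { kConst, kVar, kLet, kRefWrite, kImpureCall } kind;
  std::string name;  // kVar: referenced variable; kLet: bound variable
  ExprPtr value;     // kLet: bound value; kRefWrite: value being written
  ExprPtr body;      // kLet: body
};

ExprPtr Make(Expr::Kind kind, std::string name = "", ExprPtr value = nullptr,
             ExprPtr body = nullptr) {
  return std::make_shared<Expr>(Expr{kind, std::move(name), std::move(value), std::move(body)});
}

class ToyPurityVisitor {
 public:
  Purity Visit(const ExprPtr& e) {
    switch (e->kind) {
      case Expr::kConst:
        return {true, true};
      case Expr::kImpureCall:
        return {false, true};
      case Expr::kVar:
        assert(var_to_purity_.count(e->name) == 1);  // must be bound already
        return {true, var_to_purity_[e->name].pure_call};
      case Expr::kLet: {
        // Optimistically assume call purity in case the value refers back to the variable.
        var_to_purity_[e->name] = {true, true};
        Purity value = Visit(e->value);
        var_to_purity_[e->name] = value;  // now record the value's real purity
        Purity body = Visit(e->body);
        return {value.pure_eval && body.pure_eval, body.pure_call};
      }
      case Expr::kRefWrite:
        // The written value's call purity is dropped, but it is still visited
        // so that let-bindings nested inside it are recorded.
        (void)Visit(e->value);
        return {false, true};  // mutating a reference is never pure
    }
    return {false, false};  // unreachable for well-formed input
  }

  const std::unordered_map<std::string, Purity>& purity_map() const {
    return var_to_purity_;
  }

 private:
  std::unordered_map<std::string, Purity> var_to_purity_;
};
```

If the `kRefWrite` case skipped `Visit(e->value)`, a binding such as `let z = ...` inside the written value would never appear in `purity_map()`.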







[GitHub] [tvm] electriclilies commented on a change in pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

electriclilies commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754716450



##########
File path: include/tvm/relay/transform.h
##########
@@ -540,7 +550,7 @@ TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
 /*!
  * \brief Remove the continuation argument of a CPS function.
  *
- * Note that this only transform the type back into un-CPS form
+ * Note that this only transform the type back into un-CPS formA

Review comment:
       You added an "A" here by accident :) 

##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
  */
 
 /*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
  *
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used value.
+ * TODO(mbs): Track dead writes into references.
  */
+
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
-#include "let_list.h"
+#include "../op/call/call.h"
 
 namespace tvm {
 namespace relay {
+namespace {
 
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyize. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+  /*!
+   * \brief True if evaling the sub-expression itself is pure.
+   */
+  bool pure_eval;
+  /*!
+   * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling
+   * a call to the the sub-expression is pure. See [RULE A] below.
+   */
+  bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f:
+ *  - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is
+ *    also call-pure.
+ *  - [RULE B] Higher-order function arguments are assumed call impure.
+ *  - [RULE C] We assume functions extracted from tuples are call impure.
+ *  - [RULE D] We assume functions extracted from references are call impure.
+ *  - [RULE E] We assume functions extracted from ADTs are call impure.
+ *  - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+  explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {}
+
+  /*! \brief Visit all the functions in the module. */
+  void VisitModule() {
+    VLOG_CONTEXT << "PurityVisitor";
+    // It is safe to visit the global functions in any order. Recursive global functions are
+    // allowed.
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
+        if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+            function_node->GetAttr<String>(attr::kExternalSymbol)) {
+          // Ignore primitive and external functions.
+          continue;
+        }
+        // Everything of interest will be recorded in the purity maps so we ignore the result.
+        (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+      }
+    }
+  }
+
+  /*!
+   * \brief Returns a map from every let-bound variable to whether its let-bound value is
+   * definitely pure.
+   */
+  std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+    std::unordered_map<const VarNode*, bool> result;
+    for (const auto& kv : var_to_purity_) {
+      result.emplace(kv.first, kv.second.pure_eval);
+    }
+    return result;
+  }
 
-class CalcDep;
-class FindDef : private ExprVisitor {
  private:
-  VarMap<Expr> expr_map_;
+  Purity VisitExpr(const Expr& expr) final {
+    auto it = memo_.find(expr.get());
+    if (it != this->memo_.end()) {
+      return it->second;
+    } else {
+      Purity result = ExprFunctor::VisitExpr(expr);
+      memo_[expr.get()] = result;
+      return result;
+    }
+  }
 
-  void VisitExpr_(const LetNode* l) final {
-    auto pre_visit = [this](const LetNode* op) {
-      ICHECK_EQ(expr_map_.count(op->var), 0);
-      expr_map_[op->var] = op->value;
-      this->VisitExpr(op->value);
-    };
-    auto post_visit = [this](const LetNode* op) {
-      this->VisitExpr(op->body);
-      this->visit_counter_[op] += 1;
-    };
-    ExpandANormalForm(l, pre_visit, post_visit);
+  Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; }
+
+  Purity VisitExpr_(const ConstructorNode*) final {
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const OpNode* op_node) final {
+    // Primitive operators are pure unless marked as 'stateful'.
+    static OpAttrMap<bool> attr_map = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+    bool is_statefull = attr_map.count(GetRef<Op>(op_node)) && attr_map[GetRef<Op>(op_node)];
+    return {/*pure_eval=*/true, /*pure_call=*/!is_statefull};
+  }
+
+  Purity VisitExpr_(const GlobalVarNode* global_var_node) final {
+    auto global_var = GetRef<GlobalVar>(global_var_node);
+    auto func = mod_->Lookup(global_var);
+    if (const auto* function_node = func.as<FunctionNode>()) {
+      if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
+        return VisitGlobalFunction(global_var, GetRef<Function>(function_node));
+      }
+    }
+    // Assume externals and PrimFuncs are call-impure [RULE F].
+    // (If they are pure then we should have dealt with them before lowering.)
+    return {/*pure_eval==*/true, /*pure_call=*/false};
+  }
+
+  Purity VisitExpr_(const VarNode* var_node) final {
+    // The var is bound to a value, but if that value is a function we need to propagate the
+    // function body's purity.
+    ICHECK(var_to_purity_.count(var_node)) << PrettyPrint(GetRef<Var>(var_node));
+    return {/*pure_eval=*/true, /*pure_call=*/var_to_purity_[var_node].pure_call};
+  }
+
+  Purity VisitExpr_(const FunctionNode* function_node) final {
+    for (const auto& param : function_node->params) {
+      // Any higher-order parameters are assumed to be call-impure [RULE B]
+      var_to_purity_[param.get()] = {/*pure_eval=*/true, /*pure_call=*/IsFirstOrder(param)};
+    }
+    Purity body_purity = VisitExpr(function_node->body);
+    // The function itself is a value and thus pure. If the function returns
+    // a function we'll fold its purity in here [RULE A]
+    return {/*pure_eval=*/true, /*pure_call=*/body_purity.pure_eval && body_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const LetNode* let_node) final {
+    Expr expr = GetRef<Expr>(let_node);
+    bool all_values_pure_eval = true;
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      // In case the value is a recursive function assume the let-bound variable is call-pure.
+      var_to_purity_[inner_let_node->var.get()] = {/*pure_eval=*/true, /*pure_call=*/true};
+      Purity value_purity = VisitExpr(inner_let_node->value);
+      // Now revise the variable to it's true purity.
+      var_to_purity_[inner_let_node->var.get()] = value_purity;
+      VLOG(2) << (value_purity.pure_eval ? "pure" : "impure") << " expression:" << std::endl
+              << PrettyPrint(inner_let_node->value) << std::endl
+              << "let-bound to variable:" << std::endl
+              << PrettyPrint(inner_let_node->var);
+      all_values_pure_eval = all_values_pure_eval && value_purity.pure_eval;
+      expr = inner_let_node->body;
+    }
+    Purity body_purity = VisitExpr(expr);
+    return {/*pure_eval=*/all_values_pure_eval && body_purity.pure_eval,
+            /*pure_call=*/body_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const CallNode* call_node) final {
+    if (current_call_depth_ >= kMaxCallDepth) {
+      // Assume impure.
+      VLOG(2) << "assuming call is impure since too deeply nested";
+      return {/*pure_eval=*/false, /*pure_call*/ IsFirstOrder(GetRef<Call>(call_node))};
+    }
+
+    ++current_call_depth_;
+
+    // We can work with the call in both pre- and post-lowered form.
+    Expr callee;
+    Array<Expr> args;
+    if (call_node->op == CallLoweredOp()) {
+      CallLoweredProps props = GetCallLoweredProps(call_node);
+      callee = props.lowered_func;
+      args = props.arguments;
+    } else {
+      callee = call_node->op;
+      args = call_node->args;
+    }
+
+    // Find purity for the callee and the args.
+    Purity callee_purity = VisitExpr(callee);
+    bool all_args_pure_eval = true;
+    for (const auto& arg : args) {
+      Purity arg_purity = VisitExpr(arg);
+      all_args_pure_eval = all_args_pure_eval && arg_purity.pure_eval;
+    }
+
+    VLOG(2) << (callee_purity.pure_call ? "pure" : "impure") << " call to:" << std::endl
+            << PrettyPrint(callee);
+
+    ICHECK_GT(current_call_depth_, 0);
+    --current_call_depth_;
+
+    // If the callee's result is itself a function then by [RULE A] its purity
+    // is given by callee_purity.pure_call.
+    return {/*pure_eval=*/all_args_pure_eval && callee_purity.pure_eval && callee_purity.pure_call,
+            /*pure_call=*/IsFirstOrder(GetRef<Call>(call_node)) || callee_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const IfNode* if_node) final {
+    Purity cond_purity = VisitExpr(if_node->cond);
+    ICHECK(cond_purity.pure_call);  // conditional is first-order
+    Purity true_purity = VisitExpr(if_node->true_branch);
+    Purity false_purity = VisitExpr(if_node->false_branch);
+    return {/*pure_eval=*/cond_purity.pure_eval && true_purity.pure_eval && false_purity.pure_eval,
+            /*pure_call=*/true_purity.pure_call && false_purity.pure_call};
+  }
+
+  Purity VisitExpr_(const TupleNode* tuple_node) final {
+    bool all_fields_pure = true;
+    for (const auto& field : tuple_node->fields) {
+      // The call purity of each tuple field is lost [RULE C].
+      Purity field_purity = VisitExpr(field);
+      if (!field_purity.pure_eval) {
+        all_fields_pure = false;
+      }
+    }
+    return {/*pure_eval=*/all_fields_pure, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    Purity tuple_purity = VisitExpr(tuple_get_item_node->tuple);
+    ICHECK(tuple_purity.pure_call);  // tuple is first-order
+    // We don't track call purity through tuple fields, so if the result is a function type we
+    // must assume it is call impure [RULE C].
+    return {/*pure_eval=*/tuple_purity.pure_eval,
+            /*pure_call=*/IsFirstOrder(GetRef<TupleGetItem>(tuple_get_item_node))};
+  }
+
+  Purity VisitExpr_(const RefCreateNode*) final {
+    // The creation of the ref itself is unobservable other than via the reads/writes into it.
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const RefWriteNode* ref_write_node) final {
+    Purity ref_purity = VisitExpr(ref_write_node->ref);
+    ICHECK(ref_purity.pure_call);  // reference is first-order
+    // The call purity of the written value is lost [RULE D].

Review comment:
       If the call purity of the written value is lost, why do we need to visit it?
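The `VisitExpr_` overloads quoted in the hunk above combine the two purity flags using a handful of boolean formulas. The following is a minimal, stand-alone C++ sketch of just those formulas so they can be checked in isolation. It is not TVM code: `AllPureEval`, `CombineLet`, `CombineCall`, `CombineIf`, `CombineTuple` and `CombineTupleGetItem` are hypothetical helpers, and the `result_is_first_order` parameter stands in for the diff's `IsFirstOrder(...)` check.

```cpp
#include <cassert>
#include <vector>

// Stand-alone copy of the two-flag summary from the diff.
struct Purity {
  bool pure_eval;  // evaluating the expression itself touches no state
  bool pure_call;  // calling the expression's (function) value touches no state
};

// Conjunction of pure_eval over a list of sub-results.
bool AllPureEval(const std::vector<Purity>& parts) {
  for (const Purity& p : parts) {
    if (!p.pure_eval) return false;
  }
  return true;
}

// let x1 = v1; ... let xn = vn; body
// Every bound value and the body must be eval-pure; call purity comes from the body.
Purity CombineLet(const std::vector<Purity>& values, Purity body) {
  return {AllPureEval(values) && body.pure_eval, body.pure_call};
}

// callee(args...): eval-pure only if the callee, the arguments and the call itself are pure.
// A higher-order result inherits the callee's call purity [RULE A].
Purity CombineCall(Purity callee, const std::vector<Purity>& args, bool result_is_first_order) {
  return {AllPureEval(args) && callee.pure_eval && callee.pure_call,
          result_is_first_order || callee.pure_call};
}

// if (cond) { t } else { f }: the condition is a first-order boolean.
Purity CombineIf(Purity cond, Purity t, Purity f) {
  assert(cond.pure_call && "conditional is first-order");
  return {cond.pure_eval && t.pure_eval && f.pure_eval, t.pure_call && f.pure_call};
}

// (f1, ..., fn): the fields' call purity is dropped [RULE C]; the tuple itself is first-order.
Purity CombineTuple(const std::vector<Purity>& fields) {
  return {AllPureEval(fields), /*pure_call=*/true};
}

// t.i: a function-typed field is assumed call-impure [RULE C].
Purity CombineTupleGetItem(Purity tuple, bool result_is_first_order) {
  return {tuple.pure_eval, result_is_first_order};
}
```

For example, a call through a call-impure callee is never eval-pure, but its result still counts as call-pure when that result is first-order, matching the `CallNode` case in the diff.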

##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
  */
 
 /*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
  *
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used value.
+ * TODO(mbs): Track dead writes into references.
  */
+
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
-#include "let_list.h"
+#include "../op/call/call.h"
 
 namespace tvm {
 namespace relay {
+namespace {
 
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyze. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+  /*!
+   * \brief True if evaling the sub-expression itself is pure.
+   */
+  bool pure_eval;
+  /*!
+   * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling
+   * a call to the sub-expression is pure. See [RULE A] below.
+   */
+  bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f:
+ *  - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is
+ *    also call-pure.
+ *  - [RULE B] Higher-order function arguments are assumed call impure.
+ *  - [RULE C] We assume functions extracted from tuples are call impure.
+ *  - [RULE D] We assume functions extracted from references are call impure.
+ *  - [RULE E] We assume functions extracted from ADTs are call impure.
+ *  - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */

Review comment:
       This description is very nice
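Rules B through F in the docstring above share one conservative default: when the analysis cannot see which function is being called, it assumes the call is impure. A minimal C++ sketch of that decision follows. The `FunctionSource` enum and the `AssumeCallPure` function are hypothetical names introduced here for illustration; they do not appear in the PR, which applies these rules inside the individual `VisitExpr_` overloads.

```cpp
#include <cassert>

// Hypothetical classification of where a function value came from.
enum class FunctionSource {
  kAnalyzedFunction,  // a function whose body the visitor has analyzed
  kParameter,         // [RULE B] higher-order function argument
  kTupleField,        // [RULE C] projected out of a tuple
  kRefRead,           // [RULE D] read out of a reference
  kAdtField,          // [RULE E] extracted from an ADT
  kExternal,          // [RULE F] external Function or PrimFunc
};

// Whether a call to a function of the given provenance may be treated as pure.
// analyzed_call_pure is consulted only for kAnalyzedFunction.
bool AssumeCallPure(FunctionSource source, bool analyzed_call_pure) {
  switch (source) {
    case FunctionSource::kAnalyzedFunction:
      return analyzed_call_pure;
    case FunctionSource::kParameter:
    case FunctionSource::kTupleField:
    case FunctionSource::kRefRead:
    case FunctionSource::kAdtField:
    case FunctionSource::kExternal:
      return false;  // callee is not visible: conservatively impure
  }
  return false;  // unreachable; silences missing-return warnings
}
```

Falling back to `false` can only make dead code elimination keep more let-bindings, never remove a side-effecting one, which is why the imprecision is acceptable.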

##########
File path: src/relay/transforms/dead_code.cc
##########
@@ -18,158 +18,565 @@
  */
 
 /*!
+ * \file src/relay/transforms/dead_code.cc
+ * \brief Elides or inlines let-bindings.
  *
- * \file dead_code.cc
- *
- * \brief Remove code that does not effect the program result.
- *
- * The algorithm is implemented by two visitor:
- * CalcDep turn an expr into a dependency graph of expr,
- * GenLet turn the dependency graph into a let list, taking only the used value.
+ * TODO(mbs): Track dead writes into references.
  */
+
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/pattern_functor.h>
 #include <tvm/relay/transform.h>
 
-#include "let_list.h"
+#include "../op/call/call.h"
 
 namespace tvm {
 namespace relay {
+namespace {
 
-template <typename X>
-using VarMap = std::unordered_map<Var, X, ObjectPtrHash, ObjectPtrEqual>;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+/*! \brief Maximum depth of calls to analyze. */
+constexpr int kMaxCallDepth = 25;
+
+/*!
+ * \brief Captures (an approximation of) the purity for a Relay sub-expression. A pure
+ * sub-expression is guaranteed never to access or mutate state. Thus the sub-expression
+ * can safely be elided (if its result is never used), or inlined (which may change the
+ * number of times and program order for the evaluation.)
+ */
+struct Purity {
+  /*!
+   * \brief True if evaling the sub-expression itself is pure.
+   */
+  bool pure_eval;
+  /*!
+   * \brief If the sub-expression is first-order then always true. Otherwise true only if evaling
+   * a call to the sub-expression is pure. See [RULE A] below.
+   */
+  bool pure_call;
+};
+
+/*!
+ * \brief Visits all the global functions in a module and records the purity of every let-bound
+ * value.
+ *
+ * (See also inline.cc for function inlining.)
+ *
+ * Generally we track whether evaluation of a sub-expression is definitely pure. However for
+ * sub-expressions f of higher-order type we also track the 'call purity' of evaling a call to f:
+ *  - [RULE A] If f's result is itself higher-order then f is call-pure only if the result of f is
+ *    also call-pure.
+ *  - [RULE B] Higher-order function arguments are assumed call impure.
+ *  - [RULE C] We assume functions extracted from tuples are call impure.
+ *  - [RULE D] We assume functions extracted from references are call impure.
+ *  - [RULE E] We assume functions extracted from ADTs are call impure.
+ *  - [RULE F] We assume all external Functions and PrimFuncs are call impure.
+ */
+class PurityVisitor : ExprFunctor<Purity(const Expr&)> {
+ public:
+  explicit PurityVisitor(IRModule mod) : mod_(std::move(mod)), current_call_depth_(0) {}
+
+  /*! \brief Visit all the functions in the module. */
+  void VisitModule() {
+    VLOG_CONTEXT << "PurityVisitor";
+    // It is safe to visit the global functions in any order. Recursive global functions are
+    // allowed.
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = kv.second.as<FunctionNode>()) {
+        if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+            function_node->GetAttr<String>(attr::kExternalSymbol)) {
+          // Ignore primitive and external functions.
+          continue;
+        }
+        // Everything of interest will be recorded in the purity maps so we ignore the result.
+        (void)VisitGlobalFunction(kv.first, GetRef<Function>(function_node));
+      }
+    }
+  }
+
+  /*!
+   * \brief Returns a map from every let-bound variable to whether its let-bound value is
+   * definitely pure.
+   */
+  std::unordered_map<const VarNode*, bool> GetPurityMap() const {
+    std::unordered_map<const VarNode*, bool> result;
+    for (const auto& kv : var_to_purity_) {
+      result.emplace(kv.first, kv.second.pure_eval);
+    }
+    return result;
+  }
 
-class CalcDep;
-class FindDef : private ExprVisitor {
  private:
-  VarMap<Expr> expr_map_;
+  Purity VisitExpr(const Expr& expr) final {
+    auto it = memo_.find(expr.get());
+    if (it != this->memo_.end()) {
+      return it->second;
+    } else {
+      Purity result = ExprFunctor::VisitExpr(expr);
+      memo_[expr.get()] = result;
+      return result;
+    }
+  }
 
-  void VisitExpr_(const LetNode* l) final {
-    auto pre_visit = [this](const LetNode* op) {
-      ICHECK_EQ(expr_map_.count(op->var), 0);
-      expr_map_[op->var] = op->value;
-      this->VisitExpr(op->value);
-    };
-    auto post_visit = [this](const LetNode* op) {
-      this->VisitExpr(op->body);
-      this->visit_counter_[op] += 1;
-    };
-    ExpandANormalForm(l, pre_visit, post_visit);
+  Purity VisitExpr_(const ConstantNode*) final { return {/*pure_eval=*/true, /*pure_call=*/true}; }
+
+  Purity VisitExpr_(const ConstructorNode*) final {
+    return {/*pure_eval=*/true, /*pure_call=*/true};
+  }
+
+  Purity VisitExpr_(const OpNode* op_node) final {
+    // Primitive operators are pure unless marked as 'stateful'.
+    static OpAttrMap<bool> attr_map = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+    bool is_statefull = attr_map.count(GetRef<Op>(op_node)) && attr_map[GetRef<Op>(op_node)];

Review comment:
       sp: statefull -> stateful
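`PurityVisitor::VisitExpr` in the hunk above memoizes results by node pointer, so a sub-expression shared by several parents is analyzed only once. The sketch below shows that caching pattern on a hypothetical toy `Node` type instead of Relay's `Expr`; `MemoPurity` and its `computed()` counter are illustrative names only, and every toy node is treated as first-order (`pure_call` is always `true`).

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

struct Purity {
  bool pure_eval;
  bool pure_call;
};

// Hypothetical toy expression: its own eval purity plus sub-expressions.
struct Node {
  bool leaf_pure = true;
  std::vector<const Node*> children;
};

// Memoizing purity visitor keyed by node address, mirroring the memo_ lookup
// in PurityVisitor::VisitExpr.
class MemoPurity {
 public:
  Purity Visit(const Node* node) {
    auto it = memo_.find(node);
    if (it != memo_.end()) {
      return it->second;  // shared sub-expression: reuse the cached result
    }
    Purity result = Compute(node);
    memo_[node] = result;
    return result;
  }

  // Number of nodes actually analyzed (cache misses).
  int computed() const { return computed_; }

 private:
  Purity Compute(const Node* node) {
    ++computed_;
    bool pure = node->leaf_pure;
    for (const Node* child : node->children) {
      // Visit every child (no short-circuit) so each one ends up in the cache.
      pure = Visit(child).pure_eval && pure;
    }
    return {pure, /*pure_call=*/true};  // toy nodes are all first-order
  }

  std::unordered_map<const Node*, Purity> memo_;
  int computed_ = 0;
};
```

Because the key is the node's address (as with `expr.get()` in the diff), two structurally identical but distinct nodes are still analyzed separately; only genuinely shared nodes hit the cache.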






[GitHub] [tvm] mbs-octoml commented on a change in pull request #9542: [Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9542:
URL: https://github.com/apache/tvm/pull/9542#discussion_r754747059



##########
File path: include/tvm/relay/transform.h
##########
@@ -540,7 +550,7 @@ TVM_DLL Function ToCPS(const Function& f, const IRModule& mod);
 /*!
  * \brief Remove the continuation argument of a CPS function.
  *
- * Note that this only transform the type back into un-CPS form
+ * Note that this only transform the type back into un-CPS formA

Review comment:
       thanks, done!



