Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/14 01:21:39 UTC

[incubator-tvm] branch master updated: add memoized expr translator for use by backend codegen (#5325)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c1ca60  add memoized expr translator for use by backend codegen (#5325)
2c1ca60 is described below

commit 2c1ca60ea38587401a20f11c5e64f452b79fa777
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Apr 14 10:21:31 2020 +0900

    add memoized expr translator for use by backend codegen (#5325)
---
 src/relay/backend/compile_engine.cc            | 64 +++++++++-----------------
 src/relay/backend/contrib/codegen_c/codegen.cc | 12 +----
 src/relay/backend/contrib/dnnl/codegen.cc      | 12 +----
 src/relay/backend/graph_runtime_codegen.cc     | 50 ++------------------
 src/relay/backend/interpreter.cc               |  5 --
 src/relay/backend/utils.h                      | 35 ++++++++++++++
 6 files changed, 63 insertions(+), 115 deletions(-)
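
For context, a hypothetical usage sketch (not part of this patch): the class
name DumpCalls, its std::string output type, and the S-expression-style output
are made up for illustration. What it shows is that a backend translator now
derives from backend::MemoizedExprTranslator<T> (added in utils.h below)
instead of ExprFunctor<T(const Expr&)>, and drops its hand-written memo cache.

    #include <string>

    #include <tvm/relay/expr_functor.h>

    #include "utils.h"  // defines tvm::relay::backend::MemoizedExprTranslator

    namespace tvm {
    namespace relay {

    // Illustrative translator: renders expressions as strings.
    class DumpCalls : public backend::MemoizedExprTranslator<std::string> {
     public:
      std::string VisitExpr_(const VarNode* var) final { return var->name_hint(); }

      std::string VisitExpr_(const CallNode* call) final {
        // Arguments are visited through the memoized VisitExpr of the base
        // class, so a sub-expression shared by several calls is translated once.
        std::string out = "(call";
        for (const auto& arg : call->args) {
          out += " " + VisitExpr(arg);
        }
        return out + ")";
      }

      std::string VisitExprDefault_(const Object* op) final {
        LOG(FATAL) << "node not supported by this sketch: " << op->GetTypeKey();
        return {};
      }
    };

    }  // namespace relay
    }  // namespace tvm

The per-node VisitExpr_ overrides keep their behavior; only the caching
boilerplate moves into the shared base class, which is what the deletions in
compile_engine.cc, graph_runtime_codegen.cc, and the two contrib codegens
below amount to.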

diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 4ed8fbc..ce0a314 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -21,29 +21,31 @@
  * \file relay/backend/compile_engine.cc
  * \brief Internal compilation engine.
  */
+#include "compile_engine.h"
+
+#include <topi/tags.h>
+#include <tvm/driver/driver_api.h>
 #include <tvm/ir/type_functor.h>
-#include <tvm/te/schedule.h>
-#include <tvm/te/operation.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
-#include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
-#include <tvm/driver/driver_api.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
 
-#include <topi/tags.h>
-#include <utility>
+#include <functional>
 #include <limits>
 #include <mutex>
-#include <functional>
-#include <vector>
 #include <unordered_map>
+#include <utility>
+#include <vector>
 
-#include "compile_engine.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relay {
@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
 
 // The getter to get schedule from compile engine.
 // Get schedule from functor.
-class ScheduleGetter :
-      public ExprFunctor<Array<te::Tensor>(const Expr&)> {
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
   explicit ScheduleGetter(Target target)
       : target_(target), device_copy_op_(Op::Get("device_copy")) {}
@@ -179,17 +180,6 @@ class ScheduleGetter :
     return CachedFunc(cache_node);
   }
 
-  Array<te::Tensor> VisitExpr(const Expr& expr) {
-    auto it = memo_.find(expr);
-    if (it != memo_.end()) {
-      return it->second;
-    } else {
-      Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
-      memo_[expr] = res;
-      return res;
-    }
-  }
-
   Array<te::Tensor> VisitExpr_(const VarNode* op) final {
     LOG(FATAL) << "Free variable " << op->name_hint();
     return {};
@@ -327,7 +317,6 @@ class ScheduleGetter :
   int master_op_pattern_{0};
   OpImplementation master_implementation_;
   std::ostringstream readable_name_stream_;
-  std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
   Array<te::Operation> scalars_;
   // Cache device copy op for equivalence checking to reduce registry lookup
   // overhead for each invocation of call node when retrieving schedules.
@@ -335,7 +324,7 @@ class ScheduleGetter :
 };
 
 // Creates shape function from functor.
-class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
+class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
   MakeShapeFunc() {}
 
@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
     return std::make_pair(schedule, cfunc);
   }
 
-  Array<te::Tensor> VisitExpr(const Expr& expr) {
-    auto it = memo_.find(expr);
-    if (it != memo_.end()) {
-      return it->second;
-    } else {
-      Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
-      if (expr.as<VarNode>() == nullptr) {
-        // Do not memoize vars because shape functions could use either the data
-        // or the shape of a var each time.
-        memo_[expr] = res;
-      }
-      return res;
+  Array<te::Tensor> VisitExpr(const Expr& expr) final {
+    if (expr.as<VarNode>()) {
+      // Do not memoize vars because shape functions could use either the data
+      // or the shape of a var each time.
+      return ExprFunctor::VisitExpr(expr);
     }
+    // For other cases, do a memoized visit
+    return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
   }
 
   Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
   std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
   /*! \brief Map from parameter to list of shape placeholder */
   std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
-  /*! \brief Memoized visit result */
-  std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
   /*! \brief Stack of data dependencies for shape function */
   std::vector<bool> data_dependants_;
   /*! \brief Scalars used in the shape function */
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc
index fc93b73..0b3510c 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -40,18 +40,10 @@ using namespace backend;
  * purpose. Only several binary options are covered. Users
  * may need to extend them to cover more operators.
  */
-class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
-                 public CodegenCBase {
+class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
  public:
   explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
 
-  std::vector<Output> VisitExpr(const Expr& expr) final {
-    if (visited_.count(expr)) return visited_.at(expr);
-    std::vector<Output> output = ExprFunctor::VisitExpr(expr);
-    visited_[expr] = output;
-    return output;
-  }
-
   std::vector<Output> VisitExprDefault_(const Object* op) final {
     LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
     return {};
@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
   std::vector<std::string> func_decl_;
   /*! \brief The declaration statements of buffers. */
   std::vector<std::string> buf_decl_;
-  /*! \brief The name and index pairs for output. */
-  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
 };
 
 class CSourceCodegen : public CSourceModuleCodegenBase {
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 48652fc..26bc878 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
 
 // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
 // all utilities and make a base class for users to implement.
-class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
-                    public CodegenCBase {
+class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
  public:
   explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
 
-  std::vector<Output> VisitExpr(const Expr& expr) final {
-    if (visited_.count(expr)) return visited_.at(expr);
-    std::vector<Output> output = ExprFunctor::VisitExpr(expr);
-    visited_[expr] = output;
-    return output;
-  }
-
   std::vector<Output> VisitExprDefault_(const Object* op) final {
     LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
     return {};
@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
   std::vector<std::string> ext_func_body;
   /*! \brief The declaration of intermediate buffers. */
   std::vector<std::string> buf_decl_;
-  /*! \brief The cached expressions. */
-  std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
 };
 
 /*!
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index 4279db0..7b686c7 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -28,13 +28,12 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/runtime/device_api.h>
 
-
 #include <list>
 #include <string>
 #include <vector>
 
-#include "utils.h"
 #include "compile_engine.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relay {
@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
 };
 
 /*! \brief Code generator for graph runtime */
-class GraphRuntimeCodegen
-    : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
+class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
-  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
-      : mod_(mod) {
+  GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
     compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
     return {GraphNodeRef(node_id, 0)};
   }
 
-  /*! \brief Visitors */
-  std::unordered_map<Expr, std::vector<GraphNodeRef>, ObjectHash, ObjectEqual> visitor_cache_;
-
-  std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
-    if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
-    std::vector<GraphNodeRef> res;
-    if (expr.as<ConstantNode>()) {
-      res = VisitExpr_(expr.as<ConstantNode>());
-    } else if (expr.as<TupleNode>()) {
-      res = VisitExpr_(expr.as<TupleNode>());
-    } else if (expr.as<VarNode>()) {
-      res = VisitExpr_(expr.as<VarNode>());
-    } else if (expr.as<GlobalVarNode>()) {
-      res = VisitExpr_(expr.as<GlobalVarNode>());
-    } else if (expr.as<FunctionNode>()) {
-      res = VisitExpr_(expr.as<FunctionNode>());
-    } else if (expr.as<CallNode>()) {
-      res = VisitExpr_(expr.as<CallNode>());
-    } else if (expr.as<LetNode>()) {
-      res = VisitExpr_(expr.as<LetNode>());
-    } else if (expr.as<IfNode>()) {
-      res = VisitExpr_(expr.as<IfNode>());
-    } else if (expr.as<OpNode>()) {
-      res = VisitExpr_(expr.as<OpNode>());
-    } else if (expr.as<TupleGetItemNode>()) {
-      res = VisitExpr_(expr.as<TupleGetItemNode>());
-    } else if (expr.as<RefCreateNode>()) {
-      res = VisitExpr_(expr.as<RefCreateNode>());
-    } else if (expr.as<RefReadNode>()) {
-      res = VisitExpr_(expr.as<RefReadNode>());
-    } else if (expr.as<RefWriteNode>()) {
-      res = VisitExpr_(expr.as<RefWriteNode>());
-    } else if (expr.as<ConstructorNode>()) {
-      res = VisitExpr_(expr.as<ConstructorNode>());
-    } else if (expr.as<MatchNode>()) {
-      res = VisitExpr_(expr.as<MatchNode>());
-    }
-    visitor_cache_[expr] = res;
-    return res;
-  }
-
   std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
     Expr expr = GetRef<Expr>(op);
     return var_map_[expr.get()];
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 631f2d4..465f788 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -244,11 +244,6 @@ class Interpreter :
     return VisitExpr(expr);
   }
 
-  ObjectRef VisitExpr(const Expr& expr) final {
-    auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
-    return ret;
-  }
-
   ObjectRef VisitExpr_(const VarNode* var_node) final {
     return Lookup(GetRef<Var>(var_node));
   }
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index a96ffe4..65e6ae9 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -27,6 +27,7 @@
 #include <dmlc/json.h>
 #include <tvm/driver/driver_api.h>
 #include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/relay/type.h>
 #include <tvm/target/codegen.h>
@@ -42,6 +43,40 @@
 namespace tvm {
 namespace relay {
 namespace backend {
+
+/*!
+ * \brief A simple wrapper around ExprFunctor for a single argument case.
+ *  The result of visit is memoized.
+ */
+template <typename OutputType>
+class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
+  using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;
+
+ public:
+  /*! \brief virtual destructor */
+  virtual ~MemoizedExprTranslator() {}
+
+  /*!
+   * \brief The memoized call.
+   * \param n The expression node.
+   * \return The result of the call
+   */
+  virtual OutputType VisitExpr(const Expr& n) {
+    CHECK(n.defined());
+    auto it = memo_.find(n);
+    if (it != memo_.end()) {
+      return it->second;
+    }
+    auto res = BaseFunctor::VisitExpr(n);
+    memo_[n] = res;
+    return res;
+  }
+
+ protected:
+  /*! \brief Internal map used for memoization. */
+  std::unordered_map<Expr, OutputType, ObjectHash, ObjectEqual> memo_;
+};
+
 /*!
  * \brief Get the Packed Func
  *