You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text rendering).
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/14 01:21:39 UTC
[incubator-tvm] branch master updated: add memoized expr translator
for use by backend codegen (#5325)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 2c1ca60 add memoized expr translator for use by backend codegen (#5325)
2c1ca60 is described below
commit 2c1ca60ea38587401a20f11c5e64f452b79fa777
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Apr 14 10:21:31 2020 +0900
add memoized expr translator for use by backend codegen (#5325)
---
src/relay/backend/compile_engine.cc | 64 +++++++++-----------------
src/relay/backend/contrib/codegen_c/codegen.cc | 12 +----
src/relay/backend/contrib/dnnl/codegen.cc | 12 +----
src/relay/backend/graph_runtime_codegen.cc | 50 ++------------------
src/relay/backend/interpreter.cc | 5 --
src/relay/backend/utils.h | 35 ++++++++++++++
6 files changed, 63 insertions(+), 115 deletions(-)
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 4ed8fbc..ce0a314 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
+#include "compile_engine.h"
+
+#include <topi/tags.h>
+#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
-#include <tvm/te/schedule.h>
-#include <tvm/te/operation.h>
-#include <tvm/te/schedule_pass.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/runtime/container.h>
-#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
-#include <tvm/driver/driver_api.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
-#include <topi/tags.h>
-#include <utility>
+#include <functional>
#include <limits>
#include <mutex>
-#include <functional>
-#include <vector>
#include <unordered_map>
+#include <utility>
+#include <vector>
-#include "compile_engine.h"
+#include "utils.h"
namespace tvm {
namespace relay {
@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// The getter to get schedule from compile engine.
// Get schedule from functor.
-class ScheduleGetter :
- public ExprFunctor<Array<te::Tensor>(const Expr&)> {
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
@@ -179,17 +180,6 @@ class ScheduleGetter :
return CachedFunc(cache_node);
}
- Array<te::Tensor> VisitExpr(const Expr& expr) {
- auto it = memo_.find(expr);
- if (it != memo_.end()) {
- return it->second;
- } else {
- Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
- memo_[expr] = res;
- return res;
- }
- }
-
Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
@@ -327,7 +317,6 @@ class ScheduleGetter :
int master_op_pattern_{0};
OpImplementation master_implementation_;
std::ostringstream readable_name_stream_;
- std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
Array<te::Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
@@ -335,7 +324,7 @@ class ScheduleGetter :
};
// Creates shape function from functor.
-class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
+class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
MakeShapeFunc() {}
@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return std::make_pair(schedule, cfunc);
}
- Array<te::Tensor> VisitExpr(const Expr& expr) {
- auto it = memo_.find(expr);
- if (it != memo_.end()) {
- return it->second;
- } else {
- Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
- if (expr.as<VarNode>() == nullptr) {
- // Do not memoize vars because shape functions could use either the data
- // or the shape of a var each time.
- memo_[expr] = res;
- }
- return res;
+ Array<te::Tensor> VisitExpr(const Expr& expr) final {
+ if (expr.as<VarNode>()) {
+ // Do not memoize vars because shape functions could use either the data
+ // or the shape of a var each time.
+ return ExprFunctor::VisitExpr(expr);
}
+ // For other case, do memoized visit
+ return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
}
Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
- /*! \brief Memoized visit result */
- std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc
index fc93b73..0b3510c 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only several binary options are covered. Users
* may need to extend them to cover more operators.
*/
-class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
- public CodegenCBase {
+class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }
- std::vector<Output> VisitExpr(const Expr& expr) final {
- if (visited_.count(expr)) return visited_.at(expr);
- std::vector<Output> output = ExprFunctor::VisitExpr(expr);
- visited_[expr] = output;
- return output;
- }
-
std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
return {};
@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
- /*! \brief The name and index pairs for output. */
- std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};
class CSourceCodegen : public CSourceModuleCodegenBase {
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 48652fc..26bc878 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
-class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
- public CodegenCBase {
+class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }
- std::vector<Output> VisitExpr(const Expr& expr) final {
- if (visited_.count(expr)) return visited_.at(expr);
- std::vector<Output> output = ExprFunctor::VisitExpr(expr);
- visited_[expr] = output;
- return output;
- }
-
std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
return {};
@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */
std::vector<std::string> buf_decl_;
- /*! \brief The cached expressions. */
- std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};
/*!
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index 4279db0..7b686c7 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
-
#include <list>
#include <string>
#include <vector>
-#include "utils.h"
#include "compile_engine.h"
+#include "utils.h"
namespace tvm {
namespace relay {
@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
};
/*! \brief Code generator for graph runtime */
-class GraphRuntimeCodegen
- : public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
+class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
public:
- GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
- : mod_(mod) {
+ GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
compile_engine_ = CompileEngine::Global();
targets_ = targets;
}
@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return {GraphNodeRef(node_id, 0)};
}
- /*! \brief Visitors */
- std::unordered_map<Expr, std::vector<GraphNodeRef>, ObjectHash, ObjectEqual> visitor_cache_;
-
- std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
- if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
- std::vector<GraphNodeRef> res;
- if (expr.as<ConstantNode>()) {
- res = VisitExpr_(expr.as<ConstantNode>());
- } else if (expr.as<TupleNode>()) {
- res = VisitExpr_(expr.as<TupleNode>());
- } else if (expr.as<VarNode>()) {
- res = VisitExpr_(expr.as<VarNode>());
- } else if (expr.as<GlobalVarNode>()) {
- res = VisitExpr_(expr.as<GlobalVarNode>());
- } else if (expr.as<FunctionNode>()) {
- res = VisitExpr_(expr.as<FunctionNode>());
- } else if (expr.as<CallNode>()) {
- res = VisitExpr_(expr.as<CallNode>());
- } else if (expr.as<LetNode>()) {
- res = VisitExpr_(expr.as<LetNode>());
- } else if (expr.as<IfNode>()) {
- res = VisitExpr_(expr.as<IfNode>());
- } else if (expr.as<OpNode>()) {
- res = VisitExpr_(expr.as<OpNode>());
- } else if (expr.as<TupleGetItemNode>()) {
- res = VisitExpr_(expr.as<TupleGetItemNode>());
- } else if (expr.as<RefCreateNode>()) {
- res = VisitExpr_(expr.as<RefCreateNode>());
- } else if (expr.as<RefReadNode>()) {
- res = VisitExpr_(expr.as<RefReadNode>());
- } else if (expr.as<RefWriteNode>()) {
- res = VisitExpr_(expr.as<RefWriteNode>());
- } else if (expr.as<ConstructorNode>()) {
- res = VisitExpr_(expr.as<ConstructorNode>());
- } else if (expr.as<MatchNode>()) {
- res = VisitExpr_(expr.as<MatchNode>());
- }
- visitor_cache_[expr] = res;
- return res;
- }
-
std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op);
return var_map_[expr.get()];
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 631f2d4..465f788 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -244,11 +244,6 @@ class Interpreter :
return VisitExpr(expr);
}
- ObjectRef VisitExpr(const Expr& expr) final {
- auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
- return ret;
- }
-
ObjectRef VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node));
}
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index a96ffe4..65e6ae9 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
@@ -42,6 +43,40 @@
namespace tvm {
namespace relay {
namespace backend {
+
+/*!
+ * \brief A simple wrapper around ExprFunctor for a single argument case.
+ * The result of visit is memoized.
+ */
+template <typename OutputType>
+class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
+ using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;
+
+ public:
+ /*! \brief virtual destructor */
+ virtual ~MemoizedExprTranslator() {}
+
+ /*!
+ * \brief The memoized call.
+ * \param n The expression node.
+ * \return The result of the call
+ */
+ virtual OutputType VisitExpr(const Expr& n) {
+ CHECK(n.defined());
+ auto it = memo_.find(n);
+ if (it != memo_.end()) {
+ return it->second;
+ }
+ auto res = BaseFunctor::VisitExpr(n);
+ memo_[n] = res;
+ return res;
+ }
+
+ protected:
+ /*! \brief Internal map used for memoization. */
+ std::unordered_map<Expr, OutputType, ObjectHash, ObjectEqual> memo_;
+};
+
/*!
* \brief Get the Packed Func
*