You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/03/18 03:01:23 UTC

[incubator-tvm] branch master updated: Replace UseDefaultCompiler with GetAttr (#5088)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 06bbc7c  Replace UseDefaultCompiler with GetAttr (#5088)
06bbc7c is described below

commit 06bbc7c9e4941713b4012d344b5424d4a32a9228
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Tue Mar 17 20:01:15 2020 -0700

    Replace UseDefaultCompiler with GetAttr (#5088)
---
 include/tvm/relay/function.h               |  9 ---------
 src/relay/backend/compile_engine.cc        |  4 ++--
 src/relay/backend/graph_runtime_codegen.cc |  5 +++--
 src/relay/backend/vm/compiler.cc           |  4 ++--
 src/relay/backend/vm/inline_primitives.cc  | 10 +++++-----
 src/relay/backend/vm/lambda_lift.cc        | 10 +++++-----
 src/relay/ir/function.cc                   |  5 -----
 src/relay/ir/transform.cc                  |  2 +-
 src/relay/transforms/inline.cc             | 10 +++++-----
 src/relay/transforms/to_a_normal_form.cc   |  2 +-
 10 files changed, 24 insertions(+), 37 deletions(-)

diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index f7514c7..5c5bd26 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
    */
   TVM_DLL FuncType func_type_annotation() const;
 
-  /*!
-   * \brief Check whether the function should use the TVM default compiler to build, or
-   * use other compilers.
-   *
-   * \return Whether the function will be compiled using the default compiler
-   * (e.g. those are used in the TVM stack).
-   */
-  bool UseDefaultCompiler() const;
-
   static constexpr const char* _type_key = "relay.Function";
   TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
 };
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index ccbe4df..1237c56 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       CHECK(src_func.defined());
-      if (!src_func->UseDefaultCompiler()) {
+      if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
         auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
         CHECK(code_gen.defined()) << "No external codegen is set";
         if (ext_mods.find(code_gen->value) == ext_mods.end()) {
@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
     }
     // No need to lower external functions for now. We will invoke the external
     // codegen tool once and lower all functions together.
-    if (!key->source_func->UseDefaultCompiler()) {
+    if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
       auto cache_node = make_object<CachedFuncNode>();
       const auto name_node =
           key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index 032ebcd..0587cd2 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
     auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
     Target target;
     // Handle external function
-    if (!func->UseDefaultCompiler()) {
+    if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
       target = tvm::target::ext_dev();
       CCacheKey key = (*pf0)(func, target);
       CachedFunc ext_func = (*pf1)(compile_engine_, key);
@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
     return {};
   }
   std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
-    CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen";
+    CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
+        << "Only functions supported by custom codegen";
     return {};
   }
   std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index caf429a..2fc6567 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 
     Target target;
 
-    if (!func->UseDefaultCompiler()) {
+    if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
       target = tvm::target::ext_dev();
     } else {
       // Next generate the invoke instruction.
@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
     auto cfunc = engine_->Lower(key);
 
     auto op_index = -1;
-    if (!func->UseDefaultCompiler()) {
+    if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
       op_index = context_->cached_funcs.size();
       context_->cached_funcs.push_back(cfunc);
     } else {
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 0eb6c1a..8327a6b 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator {
       auto global = pair.first;
       auto base_func = pair.second;
       if (auto* n = base_func.as<FunctionNode>()) {
-        if (!n->UseDefaultCompiler()) continue;
+        if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
         auto func = GetRef<Function>(n);
 
         DLOG(INFO) << "Before inlining primitives: " << global
                    << std::endl << AsText(func, false);
 
         func = Function(func->params,
-                                  VisitExpr(func->body),
-                                  func->ret_type,
-                                  func->type_params,
-                                  func->attrs);
+                        VisitExpr(func->body),
+                        func->ret_type,
+                        func->type_params,
+                        func->attrs);
         module_->Add(global, func, true);
 
         DLOG(INFO) << "After inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 987fdcb..fd8c351 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator {
     auto glob_funcs = module_->functions;
     for (auto pair : glob_funcs) {
       if (auto* n = pair.second.as<FunctionNode>()) {
-        if (!n->UseDefaultCompiler()) continue;
+        if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
         auto func = GetRef<Function>(n);
         func = Function(func->params,
-                                  VisitExpr(func->body),
-                                  func->ret_type,
-                                  func->type_params,
-                                  func->attrs);
+                        VisitExpr(func->body),
+                        func->ret_type,
+                        func->type_params,
+                        func->attrs);
         module_->Add(pair.first, func, true);
       }
     }
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index d371edb..b251645 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
   return FuncType(param_types, ret_type, this->type_params, {});
 }
 
-bool FunctionNode::UseDefaultCompiler() const {
-  tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
-  return !val.defined() || val->value == "default";
-}
-
 TVM_REGISTER_NODE_TYPE(FunctionNode);
 
 TVM_REGISTER_GLOBAL("relay.ir.Function")
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 919b066..59c1750 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
 
 bool FunctionPassNode::SkipFunction(const Function& func) const {
   return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
-    !(func->UseDefaultCompiler());
+    (func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
 }
 
 Pass CreateFunctionPass(
diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc
index 5f26d67..9e118ba 100644
--- a/src/relay/transforms/inline.cc
+++ b/src/relay/transforms/inline.cc
@@ -125,13 +125,13 @@ class Inliner : ExprMutator {
     CHECK(fn) << "Expected to work on a Relay function.";
 
     auto func = Function(fn->params,
-                                   fn->body,
-                                   fn->ret_type,
-                                   fn->type_params,
-                                   fn->attrs);
+                         fn->body,
+                         fn->ret_type,
+                         fn->type_params,
+                         fn->attrs);
     // Inline the function body to the caller if this function uses default
     // compiler, i.e. no external codegen is needed.
-    if (func->UseDefaultCompiler()) {
+    if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
       CHECK_EQ(func->params.size(), args.size())
           << "Mismatch found in the number of parameters and call args";
       // Bind the parameters with call args.
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 327eb62..e4722e2 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
   for (const auto& it : funcs) {
     CHECK_EQ(FreeVars(it.second).size(), 0);
     if (const auto* n = it.second.as<FunctionNode>()) {
-      if (!n->UseDefaultCompiler()) continue;
+      if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
     }
     Expr ret =
       TransformF([&](const Expr& e) {