You are viewing a plain-text version of this content. (The hyperlink to the canonical version was lost in this plain-text extraction.)
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/03/18 03:01:23 UTC
[incubator-tvm] branch master updated: Replace UseDefaultCompiler with GetAttr (#5088)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 06bbc7c Replace UseDefaultCompiler with GetAttr (#5088)
06bbc7c is described below
commit 06bbc7c9e4941713b4012d344b5424d4a32a9228
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Tue Mar 17 20:01:15 2020 -0700
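Replace UseDefaultCompiler with GetAttr (#5088)

Remove FunctionNode::UseDefaultCompiler(). Callers now check the attribute
directly with `GetAttr<tir::StringImm>(attr::kCompiler).defined()`.

This is not a purely mechanical replacement. The removed helper also treated
a kCompiler value of "default" as meaning the TVM default compiler. The new
check treats any function that carries a kCompiler attribute, including
"default", as compiled by an external codegen.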
---
include/tvm/relay/function.h | 9 ---------
src/relay/backend/compile_engine.cc | 4 ++--
src/relay/backend/graph_runtime_codegen.cc | 5 +++--
src/relay/backend/vm/compiler.cc | 4 ++--
src/relay/backend/vm/inline_primitives.cc | 10 +++++-----
src/relay/backend/vm/lambda_lift.cc | 10 +++++-----
src/relay/ir/function.cc | 5 -----
src/relay/ir/transform.cc | 2 +-
src/relay/transforms/inline.cc | 10 +++++-----
src/relay/transforms/to_a_normal_form.cc | 2 +-
10 files changed, 24 insertions(+), 37 deletions(-)
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index f7514c7..5c5bd26 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -76,15 +76,6 @@ class FunctionNode : public BaseFuncNode {
*/
TVM_DLL FuncType func_type_annotation() const;
- /*!
- * \brief Check whether the function should use the TVM default compiler to build, or
- * use other compilers.
- *
- * \return Whether the function will be compiled using the default compiler
- * (e.g. those are used in the TVM stack).
- */
- bool UseDefaultCompiler() const;
-
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index ccbe4df..1237c56 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -616,7 +616,7 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
- if (!src_func->UseDefaultCompiler()) {
+ if (src_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
@@ -690,7 +690,7 @@ class CompileEngineImpl : public CompileEngineNode {
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
- if (!key->source_func->UseDefaultCompiler()) {
+ if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index 032ebcd..0587cd2 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -424,7 +424,7 @@ class GraphRuntimeCodegen
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
Target target;
// Handle external function
- if (!func->UseDefaultCompiler()) {
+ if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
CCacheKey key = (*pf0)(func, target);
CachedFunc ext_func = (*pf1)(compile_engine_, key);
@@ -490,7 +490,8 @@ class GraphRuntimeCodegen
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
- CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen";
+ CHECK(op->GetAttr<tir::StringImm>(attr::kCompiler).defined())
+ << "Only functions supported by custom codegen";
return {};
}
std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index caf429a..2fc6567 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -471,7 +471,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
Target target;
- if (!func->UseDefaultCompiler()) {
+ if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
target = tvm::target::ext_dev();
} else {
// Next generate the invoke instruction.
@@ -489,7 +489,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto cfunc = engine_->Lower(key);
auto op_index = -1;
- if (!func->UseDefaultCompiler()) {
+ if (func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
} else {
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 0eb6c1a..8327a6b 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -122,17 +122,17 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
- if (!n->UseDefaultCompiler()) continue;
+ if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);
func = Function(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
- func->attrs);
+ VisitExpr(func->body),
+ func->ret_type,
+ func->type_params,
+ func->attrs);
module_->Add(global, func, true);
DLOG(INFO) << "After inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 987fdcb..fd8c351 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -187,13 +187,13 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
- if (!n->UseDefaultCompiler()) continue;
+ if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);
func = Function(func->params,
- VisitExpr(func->body),
- func->ret_type,
- func->type_params,
- func->attrs);
+ VisitExpr(func->body),
+ func->ret_type,
+ func->type_params,
+ func->attrs);
module_->Add(pair.first, func, true);
}
}
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index d371edb..b251645 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -55,11 +55,6 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncType(param_types, ret_type, this->type_params, {});
}
-bool FunctionNode::UseDefaultCompiler() const {
- tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
- return !val.defined() || val->value == "default";
-}
-
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay.ir.Function")
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 919b066..59c1750 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -140,7 +140,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
bool FunctionPassNode::SkipFunction(const Function& func) const {
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
- !(func->UseDefaultCompiler());
+ (func->GetAttr<tir::StringImm>(attr::kCompiler).defined());
}
Pass CreateFunctionPass(
diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc
index 5f26d67..9e118ba 100644
--- a/src/relay/transforms/inline.cc
+++ b/src/relay/transforms/inline.cc
@@ -125,13 +125,13 @@ class Inliner : ExprMutator {
CHECK(fn) << "Expected to work on a Relay function.";
auto func = Function(fn->params,
- fn->body,
- fn->ret_type,
- fn->type_params,
- fn->attrs);
+ fn->body,
+ fn->ret_type,
+ fn->type_params,
+ fn->attrs);
// Inline the function body to the caller if this function uses default
// compiler, i.e. no external codegen is needed.
- if (func->UseDefaultCompiler()) {
+ if (!func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
CHECK_EQ(func->params.size(), args.size())
<< "Mismatch found in the number of parameters and call args";
// Bind the parameters with call args.
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 327eb62..e4722e2 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
- if (!n->UseDefaultCompiler()) continue;
+ if (n->GetAttr<tir::StringImm>(attr::kCompiler).defined()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {