You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/01/16 17:51:26 UTC
[tvm] branch main updated: [TIR] Support Return in TIR (#7084)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 052ad3d [TIR] Support Return in TIR (#7084)
052ad3d is described below
commit 052ad3d92d20abdf221600005c2ccb130e39b6b4
Author: ziheng <zi...@apache.org>
AuthorDate: Sat Jan 16 09:51:04 2021 -0800
[TIR] Support Return in TIR (#7084)
---
include/tvm/tir/builtin.h | 4 +++
include/tvm/tir/op.h | 9 +++++
include/tvm/tir/op_attr_types.h | 6 +++-
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/op.py | 28 +++++++++++----
src/target/llvm/codegen_llvm.cc | 12 +++++++
src/tir/op/builtin.cc | 4 +++
src/tir/op/op.cc | 4 +++
src/tir/transforms/make_packed_api.cc | 66 ++++++++++++++++++++++++++++++++--
tests/python/unittest/test_tir_base.py | 60 +++++++++++++++++++++++++++++++
10 files changed, 185 insertions(+), 10 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index a150595..6a40d86 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -42,6 +42,10 @@ namespace tir {
/*! \brief Collection of builtin intrinsics as ops */
namespace builtin {
/*!
+ * \brief Builtin op that returns a value; the value to return is its only argument.
+ */
+TVM_DLL const Op& ret();
+/*!
* \brief Reinterpret the value using the target type.
*/
TVM_DLL const Op& reinterpret();
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 4a907fc..b5a62c9 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -71,6 +71,15 @@ TVM_DLL Type GetType(const PrimExpr& expr);
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
/*!
+ * \brief Build a call to the builtin ret op that returns the given value.
+ *
+ * \param value The returned value.
+ * \param span The location of this operation in the source.
+ * \return A call expression to the ret op whose dtype matches that of value.
+ */
+TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
+
+/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \param span The location of this operation in the source.
diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h
index ec7fc17..3dcc4b9 100644
--- a/include/tvm/tir/op_attr_types.h
+++ b/include/tvm/tir/op_attr_types.h
@@ -74,7 +74,11 @@ enum class CallEffectKind : int {
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
- kEmbedInfo = 5
+ kEmbedInfo = 5,
+ /*!
+ * \brief Function that changes control flow
+ */
+ kControlJump = 6,
};
/*! \brief Use integer to record the kind. */
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 1aac55f..901c89e 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -35,7 +35,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .function import PrimFunc
from .op import call_packed, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
+from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index ca61be4..182264f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -221,6 +221,22 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)
+def ret(val):
+ """Create a tir return expression
+
+ Parameters
+ ----------
+ val : Expr
+ The returned tir expression, whose data type is int, float or void pointer.
+
+ Returns
+ -------
+ ret : PrimExpr
+ The return expression
+ """
+ return call_intrin(val.dtype, "tir.ret", val)
+
+
def any(*args, span=None):
"""Create a new expression of the union of all conditions in the arguments
@@ -241,10 +257,10 @@ def any(*args, span=None):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
- ret = _ffi_api._OpOr(args[0], args[1], span)
+ val = _ffi_api._OpOr(args[0], args[1], span)
for i in range(2, len(args)):
- ret = _ffi_api._OpOr(ret, args[i], span)
- return ret
+ val = _ffi_api._OpOr(val, args[i], span)
+ return val
def all(*args, span=None):
@@ -268,10 +284,10 @@ def all(*args, span=None):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
- ret = _ffi_api._OpAnd(args[0], args[1], span)
+ val = _ffi_api._OpAnd(args[0], args[1], span)
for i in range(2, len(args)):
- ret = _ffi_api._OpAnd(ret, args[i], span)
- return ret
+ val = _ffi_api._OpAnd(val, args[i], span)
+ return val
@tvm._ffi.register_func("tvm.default_trace_action")
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 70f094a..34f3897 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -927,6 +927,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
value->addIncoming(then_value, then_value_block);
value->addIncoming(else_value, else_value_block);
return value;
+ } else if (op->op.same_as(builtin::ret())) {
+ auto const* val = op->args[0].as<IntImmNode>();
+ ICHECK(val) << "the tir.ret should be transformed to return zero "
+ << "before the llvm code generation.";
+ ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to "
+ << "return zero before the llvm code generation.";
+ builder_->CreateRet(ConstInt32(0));
+ // LLVM allows exactly one terminator in a single basic block
+ // append a new dummy basic block to avoid error.
+ llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_);
+ builder_->SetInsertPoint(ret_dummy);
+ return ret_dummy;
} else if (op->op.same_as(builtin::reinterpret())) {
llvm::Type* target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 796b113..1117571 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -42,6 +42,10 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_num_inputs(1);
+TIR_DEFINE_BUILTIN_FUNC(ret)  // return intrinsic; kControlJump marks ops that change control flow
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
+ .set_num_inputs(1);  // the only input is the value being returned
+
TIR_DEFINE_BUILTIN_FUNC(likely)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b576fe4..9fcb071 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -145,6 +145,10 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
}
}
+// Wrap `value` in a call to the builtin ret op; the call carries value's dtype.
+PrimExpr ret(PrimExpr value, Span span) {
+  const DataType ret_dtype = value.dtype();
+  return tir::Call(ret_dtype, tir::builtin::ret(), {value}, span);
+}
+
// maximum and min limits
PrimExpr max_value(const DataType& dtype, Span span) {
using namespace tir;
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index 7c4a8ef..adbe78a 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -41,6 +41,67 @@
namespace tvm {
namespace tir {
+/*!
+ * \brief Mutator that lowers `Evaluate(tir.ret(value))` statements for the packed API.
+ *
+ * Each such statement becomes a sequence that stores the converted value into
+ * ret_var, stores the matching type code into ret_tcode, and then evaluates
+ * `tir.ret(0)`. A tir.ret found while visiting a parallel for loop is a fatal error.
+ */
+class ReturnRewriter : public StmtMutator {
+ public:
+  /*!
+   * \param ret_var Variable that receives the returned value.
+   * \param ret_tcode Variable that receives the type code of the returned value.
+   */
+  explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+
+  Stmt VisitStmt_(const ForNode* node) override {
+    // Track how many parallel loops enclose the statements being visited.
+    if (node->for_type == ForType::Parallel) in_parallel_ += 1;
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    if (node->for_type == ForType::Parallel) in_parallel_ -= 1;
+    return ret;
+  }
+
+  Stmt VisitStmt_(const EvaluateNode* node) override {
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    const EvaluateNode* eval = ret.as<EvaluateNode>();
+    ICHECK(eval);
+    // Only an Evaluate whose value is directly a call to tir.ret is rewritten.
+    if (const CallNode* call = eval->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::ret())) {
+        ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope.";
+        // 1U keeps the comparison with size() unsigned on both sides.
+        ICHECK_EQ(call->args.size(), 1U) << "tir.ret expect a single argument.";
+        ret = WriteToOut(call->args[0]);
+      }
+    }
+    return ret;
+  }
+
+ private:
+  /*!
+   * \brief Pick the FFI type code for val and cast val to the matching FFI data type.
+   *
+   * int/uint values are cast to Int(64) with kTVMArgInt, float values are cast to
+   * Float(64) with kTVMArgFloat, and void values are passed through with kTVMNullptr.
+   * Any other data type is a fatal error.
+   *
+   * \return The (type code, converted value) pair.
+   */
+  std::pair<int, PrimExpr> ConvertForFFI(PrimExpr val) {
+    DataType dtype = val.dtype();
+    if (dtype.is_int() || dtype.is_uint()) {
+      return {kTVMArgInt, Cast(DataType::Int(64), val)};
+    } else if (dtype.is_float()) {
+      return {kTVMArgFloat, Cast(DataType::Float(64), val)};
+    } else if (dtype.is_void()) {
+      return {kTVMNullptr, val};
+    } else {
+      LOG(FATAL) << "data type " << dtype << " not supported yet";
+    }
+    // Not expected to be reached after the fatal log; gives every path a return value.
+    return {kTVMNullptr, val};
+  }
+
+  /*!
+   * \brief Build the statements that replace one `Evaluate(tir.ret(val))`.
+   * \param val The value passed to tir.ret.
+   * \return Store of the value, store of its type code, then `Evaluate(tir.ret(0))`.
+   */
+  Stmt WriteToOut(PrimExpr val) {
+    std::pair<int, PrimExpr> converted = ConvertForFFI(val);
+    Stmt store_val = Store(ret_var_, converted.second, 0, const_true());
+    Stmt store_tcode = Store(ret_tcode_, converted.first, 0, const_true());
+    Stmt ret_zero = Evaluate(tvm::ret(0));
+    return SeqStmt({store_val, store_tcode, ret_zero});
+  }
+
+  /*! \brief Destination variable for the returned value. */
+  Var ret_var_;
+  /*! \brief Destination variable for the returned value's type code. */
+  Var ret_tcode_;
+  /*! \brief Number of enclosing parallel for loops at the current visit position. */
+  int in_parallel_{0};
+};
+
+// Run ReturnRewriter over `body`, using ret_var / ret_tcode as the output variables.
+Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
+  return ReturnRewriter(ret_var, ret_tcode)(body);
+}
+
// Build an AssertStmt that checks lhs == rhs, carries msg, and has Evaluate(0) as its body.
inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
  PrimExpr equal_cond = lhs == rhs;
  return AssertStmt(equal_cond, tvm::tir::StringImm(msg), Evaluate(0));
}
@@ -182,8 +243,9 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
- Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
- StringImm(name_hint + "_compute_"), func_ptr->body);
+ Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
+ body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
+ StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
PrimExpr node = StringImm("default");
diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py
new file mode 100644
index 0000000..6e081a1
--- /dev/null
+++ b/tests/python/unittest/test_tir_base.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.ir.transform import PassContext
+
+
+def build_tir_func(func):
+ func = func.with_attr("global_symbol", "main")
+ pass_ctx = PassContext.current()
+ if pass_ctx.config.get("tir.noalias", True):
+ func = func.with_attr("tir.noalias", True)
+ mod = tvm.IRModule({"main": func})
+ func = tvm.build(mod)
+ return func
+
+
+def test_scalar_add():
+ a = tir.Var("a", "float32")
+ b = tir.Var("b", "float32")
+ c = a + b
+ c = tir.ret(c)
+ c = tir.Evaluate(c)
+ func = tir.PrimFunc([a, b], c)
+ func = build_tir_func(func)
+ out = func(1.0, 2.0)
+ assert out == 3.0
+
+
+def test_control_flow_jump():
+ """Check that a ret of ``a`` under an always-true if_scope wins over a later ret of ``b``.
+
+ The built function is called with (1.0, 2.0) and must return 1.0.
+ """
+ ib = tvm.tir.ir_builder.create()
+ a = tir.Var("a", "float32")
+ b = tir.Var("b", "float32")
+ with ib.if_scope(True):
+ ib.emit(tir.Evaluate(tir.ret(a)))
+ ib.emit(tir.Evaluate(tir.ret(b)))
+ stmt = ib.get()
+ func = tir.PrimFunc([a, b], stmt)
+ func = build_tir_func(func)
+ out = func(1.0, 2.0)
+ assert out == 1.0
+
+
+if __name__ == "__main__":  # run both tests when this file is executed as a script
+ test_scalar_add()
+ test_control_flow_jump()