You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/01/16 17:51:26 UTC

[tvm] branch main updated: [TIR] Support Return in TIR (#7084)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 052ad3d  [TIR] Support Return in TIR (#7084)
052ad3d is described below

commit 052ad3d92d20abdf221600005c2ccb130e39b6b4
Author: ziheng <zi...@apache.org>
AuthorDate: Sat Jan 16 09:51:04 2021 -0800

    [TIR] Support Return in TIR (#7084)
---
 include/tvm/tir/builtin.h              |  4 +++
 include/tvm/tir/op.h                   |  9 +++++
 include/tvm/tir/op_attr_types.h        |  6 +++-
 python/tvm/tir/__init__.py             |  2 +-
 python/tvm/tir/op.py                   | 28 +++++++++++----
 src/target/llvm/codegen_llvm.cc        | 12 +++++++
 src/tir/op/builtin.cc                  |  4 +++
 src/tir/op/op.cc                       |  4 +++
 src/tir/transforms/make_packed_api.cc  | 66 ++++++++++++++++++++++++++++++++--
 tests/python/unittest/test_tir_base.py | 60 +++++++++++++++++++++++++++++++
 10 files changed, 185 insertions(+), 10 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index a150595..6a40d86 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -42,6 +42,10 @@ namespace tir {
 /*! \brief Collection of builtin intrinsics as ops */
 namespace builtin {
 /*!
+ * \brief Return a value from the enclosing function.
+ */
+TVM_DLL const Op& ret();
+/*!
  * \brief Reinterpret the value using the target type.
  */
 TVM_DLL const Op& reinterpret();
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 4a907fc..b5a62c9 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -71,6 +71,15 @@ TVM_DLL Type GetType(const PrimExpr& expr);
 TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
 
 /*!
+ * \brief Return the value.
+ *
+ * \param value The returned value.
+ * \param span The location of this operation in the source.
+ * \return The return expression.
+ */
+TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
+
+/*!
  * Query the maximum possible value of dtype.
  * \param dtype The data type.
  * \param span The location of this operation in the source.
diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h
index ec7fc17..3dcc4b9 100644
--- a/include/tvm/tir/op_attr_types.h
+++ b/include/tvm/tir/op_attr_types.h
@@ -74,7 +74,11 @@ enum class CallEffectKind : int {
   /*!
    * \brief Embed opaque information in the Expr, cannot be codegen.
    */
-  kEmbedInfo = 5
+  kEmbedInfo = 5,
+  /*!
+   * \brief Function that changes control flow
+   */
+  kControlJump = 6,
 };
 
 /*! \brief Use integer to record the kind. */
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 1aac55f..901c89e 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -35,7 +35,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
 from .function import PrimFunc
 
 from .op import call_packed, call_intrin, call_pure_extern, call_extern
-from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
+from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
 from .op import sin, sinh, asin, asinh
 from .op import cos, cosh, acos, acosh
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index ca61be4..182264f 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -221,6 +221,22 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def ret(val):
+    """Create a tir return expression
+
+    Parameters
+    ----------
+    val : Expr
+        The returned tir expression, whose data type is int, float or void pointer.
+
+    Returns
+    -------
+    ret : PrimExpr
+        The return expression
+    """
+    return call_intrin(val.dtype, "tir.ret", val)
+
+
 def any(*args, span=None):
     """Create a new experssion of the union of all conditions in the arguments
 
@@ -241,10 +257,10 @@ def any(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    ret = _ffi_api._OpOr(args[0], args[1], span)
+    val = _ffi_api._OpOr(args[0], args[1], span)
     for i in range(2, len(args)):
-        ret = _ffi_api._OpOr(ret, args[i], span)
-    return ret
+        val = _ffi_api._OpOr(val, args[i], span)
+    return val
 
 
 def all(*args, span=None):
@@ -268,10 +284,10 @@ def all(*args, span=None):
         raise ValueError("Any must take at least 1 argument")
     if len(args) == 1:
         return args[0]
-    ret = _ffi_api._OpAnd(args[0], args[1], span)
+    val = _ffi_api._OpAnd(args[0], args[1], span)
     for i in range(2, len(args)):
-        ret = _ffi_api._OpAnd(ret, args[i], span)
-    return ret
+        val = _ffi_api._OpAnd(val, args[i], span)
+    return val
 
 
 @tvm._ffi.register_func("tvm.default_trace_action")
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 70f094a..34f3897 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -927,6 +927,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
     value->addIncoming(then_value, then_value_block);
     value->addIncoming(else_value, else_value_block);
     return value;
+  } else if (op->op.same_as(builtin::ret())) {
+    auto const* val = op->args[0].as<IntImmNode>();
+    ICHECK(val) << "the tir.ret should be transformed to return zero "
+                << "before the llvm code generation.";
+    ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to "
+                             << "return zero before the llvm code generation.";
+    builder_->CreateRet(ConstInt32(0));
+    // LLVM allows exactly one terminator per basic block, so start a new
+    // dummy basic block to hold anything emitted after this return.
+    llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_);
+    builder_->SetInsertPoint(ret_dummy);
+    return ret_dummy;
   } else if (op->op.same_as(builtin::reinterpret())) {
     llvm::Type* target = DTypeToLLVMType(op->dtype);
     return builder_->CreateBitCast(MakeValue(op->args[0]), target);
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 796b113..1117571 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -42,6 +42,10 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_num_inputs(1);
 
+TIR_DEFINE_BUILTIN_FUNC(ret)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
+    .set_num_inputs(1);
+
 TIR_DEFINE_BUILTIN_FUNC(likely)
     .set_num_inputs(1)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kExprAnnotation))
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b576fe4..9fcb071 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -145,6 +145,10 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   }
 }
 
+PrimExpr ret(PrimExpr value, Span span) {
+  return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span);
+}
+
 // maximum and min limits
 PrimExpr max_value(const DataType& dtype, Span span) {
   using namespace tir;
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index 7c4a8ef..adbe78a 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -41,6 +41,67 @@
 namespace tvm {
 namespace tir {
 
+class ReturnRewriter : public StmtMutator {
+ public:
+  explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+
+  Stmt VisitStmt_(const ForNode* node) override {
+    if (node->for_type == ForType::Parallel) in_parallel_ += 1;
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    if (node->for_type == ForType::Parallel) in_parallel_ -= 1;
+    return ret;
+  }
+
+  Stmt VisitStmt_(const EvaluateNode* node) override {
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    const EvaluateNode* eval = ret.as<EvaluateNode>();
+    ICHECK(eval);
+    if (const CallNode* call = eval->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::ret())) {
+        ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope.";
+        ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
+        ret = WriteToOut(call->args[0], ret_var_, ret_tcode_);
+      }
+    }
+    return ret;
+  }
+
+ private:
+  std::pair<int, PrimExpr> ConvertForFFI(PrimExpr val) {
+    // Convert val to its FFI data type; returns the pair (type code, converted value).
+    DataType dtype = val.dtype();
+    if (dtype.is_int() || dtype.is_uint()) {
+      return {kTVMArgInt, Cast(DataType::Int(64), val)};
+    } else if (dtype.is_float()) {
+      return {kTVMArgFloat, Cast(DataType::Float(64), val)};
+    } else if (dtype.is_void()) {
+      return {kTVMNullptr, val};
+    } else {
+      LOG(FATAL) << "data type " << dtype << " not supported yet";
+    }
+    return {kTVMNullptr, val};
+  }
+
+  Stmt WriteToOut(PrimExpr val, Var ret_var, Var ret_tcode) {
+    auto p = ConvertForFFI(val);
+    int tcode = p.first;
+    val = p.second;
+    Stmt store_val = Store(ret_var_, val, 0, const_true());
+    Stmt store_tcode = Store(ret_tcode_, tcode, 0, const_true());
+    Stmt ret_zero = Evaluate(tvm::ret(0));
+    return SeqStmt({store_val, store_tcode, ret_zero});
+  }
+
+  Var ret_var_;
+  Var ret_tcode_;
+  int in_parallel_{0};
+};
+
+Stmt RewriteReturn(Stmt body, Var ret_var, Var ret_tcode) {
+  ReturnRewriter rewriter(ret_var, ret_tcode);
+  return rewriter(body);
+}
+
 inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
   return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
 }
@@ -182,8 +243,9 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
     func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
   }
 
-  Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
-                       StringImm(name_hint + "_compute_"), func_ptr->body);
+  Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
+  body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
+                  StringImm(name_hint + "_compute_"), body);
   // Set device context
   if (vmap.count(device_id.get())) {
     PrimExpr node = StringImm("default");
diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py
new file mode 100644
index 0000000..6e081a1
--- /dev/null
+++ b/tests/python/unittest/test_tir_base.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.ir.transform import PassContext
+
+
+def build_tir_func(func):
+    func = func.with_attr("global_symbol", "main")
+    pass_ctx = PassContext.current()
+    if pass_ctx.config.get("tir.noalias", True):
+        func = func.with_attr("tir.noalias", True)
+    mod = tvm.IRModule({"main": func})
+    func = tvm.build(mod)
+    return func
+
+
+def test_scalar_add():
+    a = tir.Var("a", "float32")
+    b = tir.Var("b", "float32")
+    c = a + b
+    c = tir.ret(c)
+    c = tir.Evaluate(c)
+    func = tir.PrimFunc([a, b], c)
+    func = build_tir_func(func)
+    out = func(1.0, 2.0)
+    assert out == 3.0
+
+
+def test_control_flow_jump():
+    ib = tvm.tir.ir_builder.create()
+    a = tir.Var("a", "float32")
+    b = tir.Var("b", "float32")
+    with ib.if_scope(True):
+        ib.emit(tir.Evaluate(tir.ret(a)))
+    ib.emit(tir.Evaluate(tir.ret(b)))
+    stmt = ib.get()
+    func = tir.PrimFunc([a, b], stmt)
+    func = build_tir_func(func)
+    out = func(1.0, 2.0)
+    assert out == 1.0
+
+
+if __name__ == "__main__":
+    test_scalar_add()
+    test_control_flow_jump()