You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zi...@apache.org on 2021/04/30 02:05:23 UTC

[tvm] branch main updated: [TIR][TRANSFORM] Return value support in tir.tvm_call_packed (#7932)

This is an automated email from the ASF dual-hosted git repository.

ziheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 62309e5  [TIR][TRANSFORM] Return value support in tir.tvm_call_packed (#7932)
62309e5 is described below

commit 62309e51f5b88722e0d3dc737ebe0094e09eff4b
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Apr 29 22:04:58 2021 -0400

    [TIR][TRANSFORM] Return value support in tir.tvm_call_packed (#7932)
    
    This PR fixes the return value support in tir.tvm_call_packed
    
    - Clarified the semantics of the intrinsics
    - Fix a problem when lowering call packed with nested scopes (let bindings)
    - Added regression tests to cover the changes
---
 include/tvm/tir/builtin.h                          | 38 ++++++-----
 python/tvm/tir/ir_builder.py                       | 20 ++++++
 src/tir/transforms/lower_tvm_builtin.cc            | 77 +++++++++++++---------
 .../test_tir_transform_lower_tvm_builtin.py        | 40 ++++++++++-
 4 files changed, 126 insertions(+), 49 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 6a40d86..d8248d4 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -334,11 +334,14 @@ TVM_DLL const Op& tvm_stack_make_array();
 /*!
  * \brief See pesudo code
  *
- *  int tvm_call_packed(name, TVMValue* args) {
+ *  return_type tvm_call_packed(name, TVMValue* args) {
+ *     TVMValue ret_value;
+ *     int ret_code;
  *     ModuleNode* env = GetCurrentEnv();
  *     const PackedFunc* f = env->GetFuncFromEnv(name);
- *     (*f)(args, type_code_of(args), len(args));
- *     return 0;
+ *     (*f)(args, type_code_of(args), len(args), &ret_value, &ret_code);
+ *     // return type can be int, float, handle.
+ *     return cast(return_type, ret_value.v_return_type);
  *  }
  */
 TVM_DLL const Op& tvm_call_packed();
@@ -346,11 +349,12 @@ TVM_DLL const Op& tvm_call_packed();
 /*!
  * \brief See pesudo code
  *
- *  int tvm_call_trace_packed(name, TVMValue* args) {
+ *  return_type tvm_call_trace_packed(name, TVMValue* args) {
  *     ModuleNode* env = GetCurrentEnv();
  *     const PackedFunc* f = env->GetFuncFromEnv(name);
  *     (*f)(args, type_code_of(args), len(args));
- *     return 0;
+ *     // return type can be int, float, handle.
+ *     return cast(return_type, ret_value.v_return_type);
  *  }
  */
 TVM_DLL const Op& tvm_call_trace_packed();
@@ -372,16 +376,18 @@ TVM_DLL const Op& tvm_thread_context();
  * \brief Lowered version of call packed, the space of value and
  *  type codes are explicitly allocated.
  *
- *  int tvm_call_packed_lowered(name,
- *                              TVMValue* value_stack,
- *                              int* tcode_stack,
- *                              int begin,
- *                              int end) {
+ *  return_type tvm_call_packed_lowered(name,
+ *                                      TVMValue* value_stack,
+ *                                      int* tcode_stack,
+ *                                      int begin,
+ *                                      int end) {
  *     ModuleNode* env = GetCurrentEnv();
  *     const PackedFunc* f = env->GetFuncFromEnv(name);
  *     f->CallPacked(TVMArgs(value_stack[begin:end],
  *                           tcode_stack[begin:end]),
  *                   TVMRetValue(value_stack + end, tcode_stack + end));
+ *     // return type can be int, float, handle.
+ *     return cast(return_type, load_return_from(tcode_stack + end))
  *  }
  */
 TVM_DLL const Op& tvm_call_packed_lowered();
@@ -391,16 +397,18 @@ TVM_DLL const Op& tvm_call_packed_lowered();
  *  type codes are explicitly allocated. The return value is the
  *  (end - 1) value on the stack.
  *
- *  int tvm_call_trace_packed_lowered(name,
- *                                    TVMValue* value_stack,
- *                                    int* tcode_stack,
- *                                    int begin,
- *                                    int end) {
+ *  return_type tvm_call_trace_packed_lowered(name,
+ *                                            TVMValue* value_stack,
+ *                                            int* tcode_stack,
+ *                                            int begin,
+ *                                            int end) {
  *     ModuleNode* env = GetCurrentEnv();
  *     const PackedFunc* f = env->GetFuncFromEnv(name);
  *     f->CallPacked(TVMArgs(value_stack[begin:end],
  *                           tcode_stack[begin:end]),
  *                   TVMRetValue(value_stack + end, tcode_stack + end));
+ *     // return type can be int, float, handle.
+ *     return cast(return_type, load_return_from(tcode_stack + end))
  *  }
  */
 TVM_DLL const Op& tvm_call_trace_packed_lowered();
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 2ecbded..4934bf0 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -374,6 +374,26 @@ class IRBuilder(object):
 
         return WithScope(None, _exit_cb)
 
+    def let(self, var_name, value):
+        """Create a new let stmt binding.
+
+        Parameters
+        ----------
+        var_name : str
+            The name of the variable
+
+        value : PrimExpr
+            The value to be bound
+
+        Returns
+        -------
+        var : tvm.tir.Var
+           The var that can be used in future emits.
+        """
+        var = _expr.Var(var_name, dtype=value.dtype)
+        self.emit(lambda x: _stmt.LetStmt(var, value, x))
+        return var
+
     def allocate(self, dtype, shape, name="buf", scope=None):
         """Create a allocate statement.
 
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index c40fd7e..8d2857e 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -89,14 +89,19 @@ class BuiltinLower : public StmtExprMutator {
   }
 
   Stmt VisitStmt(const Stmt& s) final {
+    // allocate space to hold prepare stmts before s
+    prep_seq_stack_.emplace_back(std::vector<Stmt>());
+
     auto stmt = StmtExprMutator::VisitStmt(s);
     auto& scope = alloca_scope_.back();
     ICHECK_EQ(scope.run_shape_stack, -1);
     ICHECK_EQ(scope.run_array_stack, 0);
 
-    if (prep_seq_.size() != 0) {
-      Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
-      prep_seq_.clear();
+    auto prep_seq = std::move(prep_seq_stack_.back());
+    prep_seq_stack_.pop_back();
+
+    if (prep_seq.size() != 0) {
+      Stmt ret = SeqStmt::Flatten(prep_seq, stmt);
       return ret;
     } else {
       return stmt;
@@ -192,6 +197,7 @@ class BuiltinLower : public StmtExprMutator {
     // if args.size() == 0, it represents a scalar shape ()
     ICHECK(!alloca_scope_.empty());
     auto& scope = alloca_scope_.back();
+    auto& prep_seq = prep_seq_stack_.back();
     if (scope.run_shape_stack == -1) {
       scope.run_shape_stack = 0;
     }
@@ -201,8 +207,8 @@ class BuiltinLower : public StmtExprMutator {
     op = expr.as<CallNode>();
     // no need to perform any store for a scalar shape
     for (size_t i = 0; i < op->args.size(); ++i) {
-      prep_seq_.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
-                                   ConstInt32(stack_begin + i), const_true(1)));
+      prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
+                                  ConstInt32(stack_begin + i), const_true(1)));
     }
     return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin);
   }
@@ -210,48 +216,54 @@ class BuiltinLower : public StmtExprMutator {
   PrimExpr MakeArray(const CallNode* op) {
     ICHECK(!alloca_scope_.empty());
     auto& scope = alloca_scope_.back();
+    auto& prep_seq = prep_seq_stack_.back();
+
     size_t idx = scope.run_array_stack;
     scope.run_array_stack += 1;
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
     op = expr.as<CallNode>();
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
+
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
     PrimExpr strides = op->args[2];
     if (!strides.defined() || is_zero(strides)) {
       strides = make_zero(DataType::Handle());
     }
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
     DataType dtype = op->args[4].dtype();
-    prep_seq_.emplace_back(
+    prep_seq.emplace_back(
         TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode,
                      make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
-                                        make_const(DataType::UInt(8), dtype.bits())));
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
-                                        make_const(DataType::UInt(16), dtype.lanes())));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
+                                       make_const(DataType::UInt(8), dtype.bits())));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
+                                       make_const(DataType::UInt(16), dtype.lanes())));
     // set byte offset
     int data_bytes = GetVectorBytes(dtype);
     PrimExpr byte_offset = op->args[5];
     if (!is_zero(byte_offset)) {
       byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
     }
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
-                                        cast(DataType::UInt(64), byte_offset)));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
+                                       cast(DataType::UInt(64), byte_offset)));
     ICHECK(device_type_.defined()) << "Unknown device type in current IR";
     ICHECK(device_id_.defined()) << "Unknown device id in current IR";
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
-                                        cast(DataType::Int(32), device_id_)));
-    prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
-                                        cast(DataType::Int(32), device_type_)));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
+                                       cast(DataType::Int(32), device_id_)));
+    prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
+                                       cast(DataType::Int(32), device_type_)));
     return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr);
   }
   // call packed.
   PrimExpr MakeCallPacked(const CallNode* op) {
     auto& scope = alloca_scope_.back();
+    auto& prep_seq = prep_seq_stack_.back();
+
     int64_t restore_shape_stack = scope.run_shape_stack;
     size_t restore_array_stack = scope.run_array_stack;
     size_t arg_stack_begin = scope.run_arg_stack;
+
     scope.run_arg_stack += op->args.size();
     // Specially handle the buffer packed intrinsic
     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
@@ -264,15 +276,15 @@ class BuiltinLower : public StmtExprMutator {
       if (t != api_type) {
         arg = Cast(api_type, arg);
       }
-      prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
-                                          static_cast<int>(arg_stack_begin + i - 1),
-                                          builtin::kTVMValueContent, arg));
+      prep_seq.emplace_back(TVMStructSet(scope.stack_value,
+                                         static_cast<int>(arg_stack_begin + i - 1),
+                                         builtin::kTVMValueContent, arg));
       int arg_tcode = api_type.code();
       if (api_type.is_handle() && arg.as<StringImmNode>()) {
         arg_tcode = kTVMStr;
       }
       if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
-      prep_seq_.emplace_back(
+      prep_seq.emplace_back(
           Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
     }
     // UPDATE stack value
@@ -285,12 +297,15 @@ class BuiltinLower : public StmtExprMutator {
     Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode,
                                    ConstInt32(arg_stack_begin),
                                    ConstInt32(arg_stack_begin + op->args.size() - 1)};
-    return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
+    // call_packed_lowered needs to do the type casting properly
+    return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args);
   }
 
   PrimExpr MakeCallTracePacked(const CallNode* op) {
     ICHECK(!alloca_scope_.empty());
     auto& scope = alloca_scope_.back();
+    auto& prep_seq = prep_seq_stack_.back();
+
     int64_t restore_shape_stack = scope.run_shape_stack;
     size_t restore_array_stack = scope.run_array_stack;
     size_t arg_stack_begin = scope.run_arg_stack;
@@ -307,12 +322,12 @@ class BuiltinLower : public StmtExprMutator {
       if (t != api_type) {
         arg = Cast(api_type, arg);
       }
-      prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
-                                          static_cast<int>(arg_stack_begin + i - 1),
-                                          builtin::kTVMValueContent, arg));
+      prep_seq.emplace_back(TVMStructSet(scope.stack_value,
+                                         static_cast<int>(arg_stack_begin + i - 1),
+                                         builtin::kTVMValueContent, arg));
       int arg_tcode = api_type.code();
       ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
-      prep_seq_.emplace_back(
+      prep_seq.emplace_back(
           Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
     }
     // UPDATE stack value
@@ -344,8 +359,8 @@ class BuiltinLower : public StmtExprMutator {
     return false;
   }
 
-  // The prepration sequence to be emitted.
-  std::vector<Stmt> prep_seq_;
+  // Stack of preparation sequences; the top one is emitted before the current statement.
+  std::vector<std::vector<Stmt>> prep_seq_stack_;
   PrimExpr device_type_;
   PrimExpr device_id_;
 
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index 8b2b26a..d6b427a 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -133,11 +133,45 @@ def check_packed_func(target="llvm"):
     tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True)
 
 
-def test_packed_func():
+def test_lower_packed_func():
     check_packed_func("llvm")
     check_packed_func("stackvm")
 
 
+@tvm.testing.requires_llvm
+def test_call_packed_return_non_i32():
+    # Test call packed that returns non-i32 types
+    expected_value = np.array([1.2, 1.4], dtype="float32")
+
+    def packed_echo(value):
+        return tvm.tir.call_intrin(
+            value.dtype, tvm.ir.Op.get("tir.tvm_call_packed"), "testing.echo", value
+        )
+
+    def build_tir():
+        Ab = tvm.tir.decl_buffer((2,), "float32")
+        ib = tvm.tir.ir_builder.create()
+        Aptr = ib.buffer_ptr(Ab)
+        # return f32
+        # Aptr[0] = testing.echo(expected_value[0])
+        Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32"))
+        # return handle
+        # let Aptr_dup = testing.echo(Aptr) in Aptr[1] = expected_value[1]
+        Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject()))
+        ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1))
+
+        stmt = ib.get()
+        return tvm.IRModule.from_expr(
+            tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test")
+        )
+
+    mod = build_tir()
+    f = tvm.build(mod, None, "llvm")
+    a = tvm.nd.array(np.zeros(2, dtype="float32"))
+    f(a)
+    tvm.testing.assert_allclose(a.asnumpy(), expected_value)
+
+
 if __name__ == "__main__":
-    # Test cases for issue: https://github.com/apache/tvm/issues/7246
-    test_packed_func()
+    test_call_packed_return_non_i32()
+    test_lower_packed_func()