You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zi...@apache.org on 2021/04/30 02:05:23 UTC
[tvm] branch main updated: [TIR][TRANSFORM] Return value support in
tir.tvm_call_packed (#7932)
This is an automated email from the ASF dual-hosted git repository.
ziheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 62309e5 [TIR][TRANSFORM] Return value support in tir.tvm_call_packed (#7932)
62309e5 is described below
commit 62309e51f5b88722e0d3dc737ebe0094e09eff4b
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Apr 29 22:04:58 2021 -0400
[TIR][TRANSFORM] Return value support in tir.tvm_call_packed (#7932)
This PR fixes the return value support in tir.tvm_call_packed
- Clarified the semantics of the intrinsics
- Fixed a problem when lowering call packed with nested scopes (let bindings)
- Added regression tests to cover the changes
---
include/tvm/tir/builtin.h | 38 ++++++-----
python/tvm/tir/ir_builder.py | 20 ++++++
src/tir/transforms/lower_tvm_builtin.cc | 77 +++++++++++++---------
.../test_tir_transform_lower_tvm_builtin.py | 40 ++++++++++-
4 files changed, 126 insertions(+), 49 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 6a40d86..d8248d4 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -334,11 +334,14 @@ TVM_DLL const Op& tvm_stack_make_array();
/*!
* \brief See pesudo code
*
- * int tvm_call_packed(name, TVMValue* args) {
+ * return_type tvm_call_packed(name, TVMValue* args) {
+ * TVMValue ret_value;
+ * int ret_code;
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
- * (*f)(args, type_code_of(args), len(args));
- * return 0;
+ * (*f)(args, type_code_of(args), len(args), &ret_value, &ret_code);
+ * // return type can be int, float, handle.
+ * return cast(return_type, ret_value.v_return_type);
* }
*/
TVM_DLL const Op& tvm_call_packed();
@@ -346,11 +349,12 @@ TVM_DLL const Op& tvm_call_packed();
/*!
* \brief See pesudo code
*
- * int tvm_call_trace_packed(name, TVMValue* args) {
+ * return_type tvm_call_trace_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
- * return 0;
+ * // return type can be int, float, handle.
+ * return cast(return_type, ret_value.v_return_type);
* }
*/
TVM_DLL const Op& tvm_call_trace_packed();
@@ -372,16 +376,18 @@ TVM_DLL const Op& tvm_thread_context();
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
*
- * int tvm_call_packed_lowered(name,
- * TVMValue* value_stack,
- * int* tcode_stack,
- * int begin,
- * int end) {
+ * return_type tvm_call_packed_lowered(name,
+ * TVMValue* value_stack,
+ * int* tcode_stack,
+ * int begin,
+ * int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
+ * // return type can be int, float, handle.
+ * return cast(return_type, load_return_from(tcode_stack + end))
* }
*/
TVM_DLL const Op& tvm_call_packed_lowered();
@@ -391,16 +397,18 @@ TVM_DLL const Op& tvm_call_packed_lowered();
* type codes are explicitly allocated. The return value is the
* (end - 1) value on the stack.
*
- * int tvm_call_trace_packed_lowered(name,
- * TVMValue* value_stack,
- * int* tcode_stack,
- * int begin,
- * int end) {
+ * return_type tvm_call_trace_packed_lowered(name,
+ * TVMValue* value_stack,
+ * int* tcode_stack,
+ * int begin,
+ * int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
+ * // return type can be int, float, handle.
+ * return cast(return_type, load_return_from(tcode_stack + end))
* }
*/
TVM_DLL const Op& tvm_call_trace_packed_lowered();
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 2ecbded..4934bf0 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -374,6 +374,26 @@ class IRBuilder(object):
return WithScope(None, _exit_cb)
+ def let(self, var_name, value):
+ """Create a new let stmt binding.
+
+ Parameters
+ ----------
+ var_name : str
+ The name of the variable
+
+ value : PrimExpr
+ The value to be bound
+
+ Returns
+ -------
+ var : tvm.tir.Var
+ The var that can be used in future emits.
+ """
+ var = _expr.Var(var_name, dtype=value.dtype)
+ self.emit(lambda x: _stmt.LetStmt(var, value, x))
+ return var
+
def allocate(self, dtype, shape, name="buf", scope=None):
"""Create a allocate statement.
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index c40fd7e..8d2857e 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -89,14 +89,19 @@ class BuiltinLower : public StmtExprMutator {
}
Stmt VisitStmt(const Stmt& s) final {
+ // allocate space to hold prepare stmts before s
+ prep_seq_stack_.emplace_back(std::vector<Stmt>());
+
auto stmt = StmtExprMutator::VisitStmt(s);
auto& scope = alloca_scope_.back();
ICHECK_EQ(scope.run_shape_stack, -1);
ICHECK_EQ(scope.run_array_stack, 0);
- if (prep_seq_.size() != 0) {
- Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
- prep_seq_.clear();
+ auto prep_seq = std::move(prep_seq_stack_.back());
+ prep_seq_stack_.pop_back();
+
+ if (prep_seq.size() != 0) {
+ Stmt ret = SeqStmt::Flatten(prep_seq, stmt);
return ret;
} else {
return stmt;
@@ -192,6 +197,7 @@ class BuiltinLower : public StmtExprMutator {
// if args.size() == 0, it represents a scalar shape ()
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
+ auto& prep_seq = prep_seq_stack_.back();
if (scope.run_shape_stack == -1) {
scope.run_shape_stack = 0;
}
@@ -201,8 +207,8 @@ class BuiltinLower : public StmtExprMutator {
op = expr.as<CallNode>();
// no need to perform any store for a scalar shape
for (size_t i = 0; i < op->args.size(); ++i) {
- prep_seq_.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
- ConstInt32(stack_begin + i), const_true(1)));
+ prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
+ ConstInt32(stack_begin + i), const_true(1)));
}
return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin);
}
@@ -210,48 +216,54 @@ class BuiltinLower : public StmtExprMutator {
PrimExpr MakeArray(const CallNode* op) {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
+ auto& prep_seq = prep_seq_stack_.back();
+
size_t idx = scope.run_array_stack;
scope.run_array_stack += 1;
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
+
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
PrimExpr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(DataType::Handle());
}
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
DataType dtype = op->args[4].dtype();
- prep_seq_.emplace_back(
+ prep_seq.emplace_back(
TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode,
make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
- make_const(DataType::UInt(8), dtype.bits())));
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
- make_const(DataType::UInt(16), dtype.lanes())));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
+ make_const(DataType::UInt(8), dtype.bits())));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
+ make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
PrimExpr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
- cast(DataType::UInt(64), byte_offset)));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
+ cast(DataType::UInt(64), byte_offset)));
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
ICHECK(device_id_.defined()) << "Unknown device id in current IR";
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
- cast(DataType::Int(32), device_id_)));
- prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
- cast(DataType::Int(32), device_type_)));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
+ cast(DataType::Int(32), device_id_)));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
+ cast(DataType::Int(32), device_type_)));
return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr);
}
// call packed.
PrimExpr MakeCallPacked(const CallNode* op) {
auto& scope = alloca_scope_.back();
+ auto& prep_seq = prep_seq_stack_.back();
+
int64_t restore_shape_stack = scope.run_shape_stack;
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;
+
scope.run_arg_stack += op->args.size();
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
@@ -264,15 +276,15 @@ class BuiltinLower : public StmtExprMutator {
if (t != api_type) {
arg = Cast(api_type, arg);
}
- prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
- static_cast<int>(arg_stack_begin + i - 1),
- builtin::kTVMValueContent, arg));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_value,
+ static_cast<int>(arg_stack_begin + i - 1),
+ builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
- prep_seq_.emplace_back(
+ prep_seq.emplace_back(
Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
@@ -285,12 +297,15 @@ class BuiltinLower : public StmtExprMutator {
Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};
- return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
+ // call_packed_lowered needs to do the type casting properly
+ return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args);
}
PrimExpr MakeCallTracePacked(const CallNode* op) {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
+ auto& prep_seq = prep_seq_stack_.back();
+
int64_t restore_shape_stack = scope.run_shape_stack;
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;
@@ -307,12 +322,12 @@ class BuiltinLower : public StmtExprMutator {
if (t != api_type) {
arg = Cast(api_type, arg);
}
- prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
- static_cast<int>(arg_stack_begin + i - 1),
- builtin::kTVMValueContent, arg));
+ prep_seq.emplace_back(TVMStructSet(scope.stack_value,
+ static_cast<int>(arg_stack_begin + i - 1),
+ builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
- prep_seq_.emplace_back(
+ prep_seq.emplace_back(
Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
@@ -344,8 +359,8 @@ class BuiltinLower : public StmtExprMutator {
return false;
}
- // The prepration sequence to be emitted.
- std::vector<Stmt> prep_seq_;
+ // The preparation sequence to be emitted before the current statement.
+ std::vector<std::vector<Stmt>> prep_seq_stack_;
PrimExpr device_type_;
PrimExpr device_id_;
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index 8b2b26a..d6b427a 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -133,11 +133,45 @@ def check_packed_func(target="llvm"):
tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True)
-def test_packed_func():
+def test_lower_packed_func():
check_packed_func("llvm")
check_packed_func("stackvm")
+@tvm.testing.requires_llvm
+def test_call_packed_return_non_i32():
+ # Test call_packed calls that return non-i32 types
+ expected_value = np.array([1.2, 1.4], dtype="float32")
+
+ def packed_echo(value):
+ return tvm.tir.call_intrin(
+ value.dtype, tvm.ir.Op.get("tir.tvm_call_packed"), "testing.echo", value
+ )
+
+ def build_tir():
+ Ab = tvm.tir.decl_buffer((2,), "float32")
+ ib = tvm.tir.ir_builder.create()
+ Aptr = ib.buffer_ptr(Ab)
+ # return f32
+ # Aptr[0] = testing.echo(expected_value[0])
+ Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32"))
+ # return handle
+ # let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1]
+ Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject()))
+ ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1))
+
+ stmt = ib.get()
+ return tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test")
+ )
+
+ mod = build_tir()
+ f = tvm.build(mod, None, "llvm")
+ a = tvm.nd.array(np.zeros(2, dtype="float32"))
+ f(a)
+ tvm.testing.assert_allclose(a.asnumpy(), expected_value)
+
+
if __name__ == "__main__":
- # Test cases for issue: https://github.com/apache/tvm/issues/7246
- test_packed_func()
+ test_call_packed_return_non_i32()
+ test_lower_packed_func()