You are viewing a plain-text version of this content; the canonical link to the original was a hyperlink that is not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/03 14:17:52 UTC
(tvm) branch main updated: [Relax][Runtime] Support Unpack API for NDArrayCache (#16648)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5718ff35ef [Relax][Runtime] Support Unpack API for NDArrayCache (#16648)
5718ff35ef is described below
commit 5718ff35ef5ba758a325cdbca191f36b84d0b549
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sun Mar 3 22:17:46 2024 +0800
[Relax][Runtime] Support Unpack API for NDArrayCache (#16648)
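As `Array` cannot be transferred through the RPC protocol, we introduce
a new unpacked API that passes all strings directly as PackedFunc arguments.
This PR also fixes a bug in `vm.builtin.ndarray_cache.update`, where the
copied array was bound to a newly declared local `arr` instead of the existing one.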
---
src/runtime/relax_vm/ndarray_cache_support.cc | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index b389030cfe..fce40157e4 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -282,7 +282,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body([](TVMArgs args,
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
- NDArray arr = NDArray::Empty(shape, tensor->dtype, tensor->device);
+ arr = NDArray::Empty(shape, tensor->dtype, tensor->device);
arr.CopyFrom(tensor);
TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr);
}
@@ -358,6 +358,19 @@ TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name")
TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams);
TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name")
.set_body_typed(ParamModuleNode::GetParamByName);
+TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked")  // Variadic form of
+ .set_body([](TVMArgs args, TVMRetValue* rv) {  // "..._by_name": one param name per argument,
+ Array<String> names;  // rebuilt here because an Array cannot be sent through the RPC protocol.
+ names.reserve(args.size());  // one slot per incoming name
+ for (int i = 0; i < args.size(); ++i) {
+ if (args[i].type_code() != kTVMStr) {  // only raw string arguments (kTVMStr) are accepted;
+ LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].type_code()
+ << " at " << i;  // the message reports the numeric type code and the argument index
+ }
+ names.push_back(args[i]);
+ }
+ *rv = ParamModuleNode::GetParamByName(names);  // same lookup as the "..._by_name" entry above
+ });
} // namespace relax_vm
} // namespace runtime