You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/03 14:17:52 UTC

(tvm) branch main updated: [Relax][Runtime] Support Unpack API for NDArrayCache (#16648)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5718ff35ef [Relax][Runtime] Support Unpack API for NDArrayCache (#16648)
5718ff35ef is described below

commit 5718ff35ef5ba758a325cdbca191f36b84d0b549
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Sun Mar 3 22:17:46 2024 +0800

    [Relax][Runtime] Support Unpack API for NDArrayCache (#16648)
    
    Because an `Array` cannot be transferred through the RPC protocol, we introduce
    a new unpacked API that takes each parameter name directly as a string argument of the PackedFunc.
    
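    This PR also fixes a bug in `vm.builtin.ndarray_cache.update`: the freshly allocated array was declared as a new local `arr` that shadowed the outer variable, so the outer `arr` never received the copy. It is now assigned to the outer `arr` instead.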
---
 src/runtime/relax_vm/ndarray_cache_support.cc | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index b389030cfe..fce40157e4 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -282,7 +282,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body([](TVMArgs args,
     for (int64_t i = 0; i < tensor->ndim; i++) {
       shape.push_back(tensor->shape[i]);
     }
-    NDArray arr = NDArray::Empty(shape, tensor->dtype, tensor->device);
+    arr = NDArray::Empty(shape, tensor->dtype, tensor->device);
     arr.CopyFrom(tensor);
     TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr);
   }
@@ -358,6 +358,19 @@ TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name")
 TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams);
 TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name")
     .set_body_typed(ParamModuleNode::GetParamByName);
+TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name_unpacked")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      Array<String> names;
+      names.reserve(args.size());
+      for (int i = 0; i < args.size(); ++i) {
+        if (args[i].type_code() != kTVMStr) {
+          LOG(FATAL) << "ValueError: Expect string as input, but get " << args[i].type_code()
+                     << " at " << i;
+        }
+        names.push_back(args[i]);
+      }
+      *rv = ParamModuleNode::GetParamByName(names);
+    });
 
 }  // namespace relax_vm
 }  // namespace runtime