You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/11/29 00:38:25 UTC

[tvm] branch main updated: Fix GraphRuntime with -link-params over RPC (#6985)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 94ce835  Fix GraphRuntime with -link-params over RPC (#6985)
94ce835 is described below

commit 94ce8353502e9b9e183bb8a61fda6108713558b8
Author: Andrew Reusch <ar...@octoml.ai>
AuthorDate: Sat Nov 28 16:38:08 2020 -0800

    Fix GraphRuntime with -link-params over RPC (#6985)
    
    * Fix GraphRuntime with remotely-linked params.
    
     * Previous test did not exercise this correctly.
    
    * Fix incorrect function name
---
 python/tvm/micro/session.py               | 2 +-
 tests/python/unittest/test_link_params.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py
index fba612b..1f91cdd 100644
--- a/python/tvm/micro/session.py
+++ b/python/tvm/micro/session.py
@@ -187,7 +187,7 @@ def lookup_remote_linked_param(mod, storage_id, template_tensor, ctx):
         return None
 
     return get_global_func("tvm.rpc.NDArrayFromRemoteOpaqueHandle")(
-        mod, remote_data, template_tensor, ctx, lambda: None
+        mod, remote_data, template_tensor, ctx, None
     )
 
 
diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py
index 7b6910b..c3c2232 100644
--- a/tests/python/unittest/test_link_params.py
+++ b/tests/python/unittest/test_link_params.py
@@ -378,8 +378,9 @@ def test_crt_link_params():
             }
             flasher = compiler.flasher(**flasher_kw)
             with tvm.micro.Session(binary=micro_binary, flasher=flasher) as sess:
-                rpc_lib = sess.get_system_lib()
-                graph_rt = tvm.contrib.graph_runtime.create(graph_json, rpc_lib, sess.context)
+                graph_rt = tvm.micro.session.create_local_graph_runtime(
+                    graph_json, sess.get_system_lib(), sess.context
+                )
 
                 # NOTE: not setting params here.
                 graph_rt.set_input("rand_input", rand_input)