You are viewing a plain-text version of this content. The canonical link for it was not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2022/01/10 19:07:02 UTC

[tvm] branch main updated: [VM] Remove undesired arg to load_late_bound_consts (#9870)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b85f7c  [VM] Remove undesired arg to load_late_bound_consts (#9870)
3b85f7c is described below

commit 3b85f7c7680d43ff9580a955eeafe97af0bceec7
Author: Michal Piszczek <im...@gmail.com>
AuthorDate: Mon Jan 10 11:06:17 2022 -0800

    [VM] Remove undesired arg to load_late_bound_consts (#9870)
    
    * Remove undesired arg to vm exec load_late_bound_consts
    
    * No-op for ci
---
 python/tvm/runtime/vm.py      |  2 +-
 tests/python/relay/test_vm.py | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index d9cab84..5395326 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -306,7 +306,7 @@ class Executable(object):
 
     def load_late_bound_consts(self, path):
         """Re-load constants previously saved to file at path"""
-        return self._load_late_bound_consts(path, bytes)
+        return self._load_late_bound_consts(path)
 
 
 class VirtualMachine(object):
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 5e07511..0740943 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -1159,6 +1159,20 @@ def test_large_constants():
     expected = x_data + const_data
     tvm.testing.assert_allclose(expected, actual.numpy())
 
+    # We load the mod again so it's missing the consts.
+    mod = runtime.load_module(path_dso)
+    exe = runtime.vm.Executable(mod)
+
+    # Also test loading consts via the VM's wrapper API.
+    exe.load_late_bound_consts(path_consts)
+
+    # Test main again with consts now loaded via the above API.
+    x_data = np.random.rand(1000, 1000).astype("float32")
+    the_vm = runtime.vm.VirtualMachine(exe, dev)
+    actual = the_vm.invoke("main", x_data)
+    expected = x_data + const_data
+    tvm.testing.assert_allclose(expected, actual.numpy())
+
 
 if __name__ == "__main__":
     import sys