You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2022/01/10 19:07:02 UTC
[tvm] branch main updated: [VM] Remove undesired arg to load_late_bound_consts (#9870)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3b85f7c [VM] Remove undesired arg to load_late_bound_consts (#9870)
3b85f7c is described below
commit 3b85f7c7680d43ff9580a955eeafe97af0bceec7
Author: Michal Piszczek <im...@gmail.com>
AuthorDate: Mon Jan 10 11:06:17 2022 -0800
[VM] Remove undesired arg to load_late_bound_consts (#9870)
* Remove undesired arg to vm exec load_late_bound_consts
* No-op for ci
---
python/tvm/runtime/vm.py | 2 +-
tests/python/relay/test_vm.py | 14 ++++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index d9cab84..5395326 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -306,7 +306,7 @@ class Executable(object):
def load_late_bound_consts(self, path):
"""Re-load constants previously saved to file at path"""
- return self._load_late_bound_consts(path, bytes)
+ return self._load_late_bound_consts(path)
class VirtualMachine(object):
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 5e07511..0740943 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -1159,6 +1159,20 @@ def test_large_constants():
expected = x_data + const_data
tvm.testing.assert_allclose(expected, actual.numpy())
+ # We load the mod again so it's missing the consts.
+ mod = runtime.load_module(path_dso)
+ exe = runtime.vm.Executable(mod)
+
+ # Also test loading consts via the VM's wrapper API.
+ exe.load_late_bound_consts(path_consts)
+
+ # Test main again with consts now loaded via the above API.
+ x_data = np.random.rand(1000, 1000).astype("float32")
+ the_vm = runtime.vm.VirtualMachine(exe, dev)
+ actual = the_vm.invoke("main", x_data)
+ expected = x_data + const_data
+ tvm.testing.assert_allclose(expected, actual.numpy())
+
if __name__ == "__main__":
import sys