You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ar...@apache.org on 2021/05/05 16:58:06 UTC
[tvm] branch main updated: [FIX,VM] Fix get_outputs on the vm with a single output (#7902)
This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f788b4 [FIX,VM] Fix get_outputs on the vm with a single output (#7902)
3f788b4 is described below
commit 3f788b41c1a93bf929a475f182a1fd0fc9f9f142
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Wed May 5 09:57:37 2021 -0700
[FIX,VM] Fix get_outputs on the vm with a single output (#7902)
* [FIX,VM] Fix get_outputs on the vm with a single output
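The VM returns an ADT when a function has multiple outputs and a bare NDArray when it has a single
output. `get_output` and `get_num_outputs` unconditionally downcast the return value to an ADT, so
the single-output case was not being handled.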
* Check that the user-specified index is valid (it must be 0 when there is only a single output)
---
src/runtime/vm/vm.cc | 18 +++++++++++++++---
tests/python/relay/test_vm.py | 37 +++++++++++++++++++++++++++++++++++++
2 files changed, 52 insertions(+), 3 deletions(-)
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index a0edb3b..17a66e4 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -142,11 +142,23 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
});
} else if (name == "get_output") {
return TypedPackedFunc<NDArray(int64_t)>([this](int64_t index) {
- return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]);
+ if (this->return_register_.as<ADTObj>()) {
+ return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]);
+ } else {
+ CHECK_EQ(index, 0) << "VM output contains only one item, but you are trying to get the "
+ << index << "th.";
+ return Downcast<NDArray>(this->return_register_);
+ }
});
} else if (name == "get_num_outputs") {
- return TypedPackedFunc<int64_t(void)>(
- [this]() -> int64_t { return Downcast<ADT>(this->return_register_).size(); });
+ return TypedPackedFunc<int64_t(void)>([this]() -> int64_t {
+ // single output is an NDArray not an ADT
+ if (this->return_register_.as<ADTObj>()) {
+ return Downcast<ADT>(this->return_register_).size();
+ } else {
+ return 1;
+ }
+ });
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size() % 3, 0);
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 7e79049..8f51869 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -852,5 +852,42 @@ def test_vm_rpc():
server.terminate()
+def test_get_output_single():
+ target = tvm.target.Target("llvm")
+
+ # Build a IRModule.
+ x = relay.var("x", shape=(10,))
+ f = relay.Function([x], x + x)
+ mod = IRModule.from_expr(f)
+
+ # Compile to VMExecutable.
+ vm_exec = vm.compile(mod, target=target)
+ vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+ inp = np.ones(10, dtype="float32")
+ vm_factory.invoke_stateful("main", inp)
+ outputs = vm_factory.get_outputs()
+ assert len(outputs) == 1
+ np.testing.assert_allclose(outputs[0].asnumpy(), inp + inp)
+
+
+def test_get_output_multiple():
+ target = tvm.target.Target("llvm")
+
+ # Build a IRModule.
+ x = relay.var("x", shape=(10,))
+ f = relay.Function([x], relay.Tuple([x + x, x]))
+ mod = IRModule.from_expr(f)
+
+ # Compile to VMExecutable.
+ vm_exec = vm.compile(mod, target=target)
+ vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+ inp = np.ones(10, dtype="float32")
+ vm_factory.invoke_stateful("main", inp)
+ outputs = vm_factory.get_outputs()
+ assert len(outputs) == 2
+ np.testing.assert_allclose(outputs[0].asnumpy(), inp + inp)
+ np.testing.assert_allclose(outputs[1].asnumpy(), inp)
+
+
if __name__ == "__main__":
pytest.main([__file__])