You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@tvm.apache.org by ar...@apache.org on 2021/05/05 16:58:06 UTC

[tvm] branch main updated: [FIX, VM] Fix get_outputs on the vm with a single output (#7902)

This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f788b4  [FIX,VM] Fix get_outputs on the vm with a single output (#7902)
3f788b4 is described below

commit 3f788b41c1a93bf929a475f182a1fd0fc9f9f142
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Wed May 5 09:57:37 2021 -0700

    [FIX,VM] Fix get_outputs on the vm with a single output (#7902)
    
    * [FIX,VM] Fix get_outputs on the vm with a single output
    
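    The VM returns an ADT (tuple) when a function has multiple outputs and a
    bare NDArray when it has a single output. `get_output` and
    `get_num_outputs` unconditionally downcast the return value to an ADT,
    so the single-output case was not handled and failed.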
    
    * Check that the user-specified index is valid: when the VM returned a
      single NDArray, the only valid index is 0.
---
 src/runtime/vm/vm.cc          | 18 +++++++++++++++---
 tests/python/relay/test_vm.py | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index a0edb3b..17a66e4 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -142,11 +142,23 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
     });
   } else if (name == "get_output") {
     return TypedPackedFunc<NDArray(int64_t)>([this](int64_t index) {
-      return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]);
+      if (this->return_register_.as<ADTObj>()) {
+        return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]);
+      } else {
+        CHECK_EQ(index, 0) << "VM output contains only one item, but you are trying to get the "
+                           << index << "th.";
+        return Downcast<NDArray>(this->return_register_);
+      }
     });
   } else if (name == "get_num_outputs") {
-    return TypedPackedFunc<int64_t(void)>(
-        [this]() -> int64_t { return Downcast<ADT>(this->return_register_).size(); });
+    return TypedPackedFunc<int64_t(void)>([this]() -> int64_t {
+      // single output is an NDArray not an ADT
+      if (this->return_register_.as<ADTObj>()) {
+        return Downcast<ADT>(this->return_register_).size();
+      } else {
+        return 1;
+      }
+    });
   } else if (name == "init") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       ICHECK_EQ(args.size() % 3, 0);
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 7e79049..8f51869 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -852,5 +852,42 @@ def test_vm_rpc():
     server.terminate()
 
 
+def test_get_output_single():
+    target = tvm.target.Target("llvm")
+
+    # Build a IRModule.
+    x = relay.var("x", shape=(10,))
+    f = relay.Function([x], x + x)
+    mod = IRModule.from_expr(f)
+
+    # Compile to VMExecutable.
+    vm_exec = vm.compile(mod, target=target)
+    vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+    inp = np.ones(10, dtype="float32")
+    vm_factory.invoke_stateful("main", inp)
+    outputs = vm_factory.get_outputs()
+    assert len(outputs) == 1
+    np.testing.assert_allclose(outputs[0].asnumpy(), inp + inp)
+
+
+def test_get_output_multiple():
+    target = tvm.target.Target("llvm")
+
+    # Build a IRModule.
+    x = relay.var("x", shape=(10,))
+    f = relay.Function([x], relay.Tuple([x + x, x]))
+    mod = IRModule.from_expr(f)
+
+    # Compile to VMExecutable.
+    vm_exec = vm.compile(mod, target=target)
+    vm_factory = runtime.vm.VirtualMachine(vm_exec, tvm.cpu())
+    inp = np.ones(10, dtype="float32")
+    vm_factory.invoke_stateful("main", inp)
+    outputs = vm_factory.get_outputs()
+    assert len(outputs) == 2
+    np.testing.assert_allclose(outputs[0].asnumpy(), inp + inp)
+    np.testing.assert_allclose(outputs[1].asnumpy(), inp)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])