Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/06 09:28:10 UTC

[tvm] branch main updated: [Executor][Bugfix] Properly return and unflatten outputs from GraphExecutor (#7604)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 875f8ee  [Executor][Bugfix] Properly return and unflatten outputs from GraphExecutor (#7604)
875f8ee is described below

commit 875f8ee2a704d1de6d6a681ac2c7b7073b73e79c
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Sat Mar 6 01:27:49 2021 -0800

    [Executor][Bugfix] Properly return and unflatten outputs from GraphExecutor (#7604)
    
    * properly return and unflatten outputs from GraphExecutor
    
    * lint
    
    * cleaner approach, not sure what I was thinking before
    
    * remove unused import
    
    * forgot copyto cpu
    
    * make solution even cleaner using iterator
---
 python/tvm/relay/build_module.py                 | 24 ++++++++++++++++--------
 tests/python/relay/test_backend_graph_runtime.py | 21 +++++++++++++++++++++
 2 files changed, 37 insertions(+), 8 deletions(-)
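For context: the old wrapper computed num_outputs as len(ret_type.fields) for a top-level tuple, but the graph runtime flattens nested tuples, so a return type like (x, (y, (z, w))) produces four flat outputs while the wrapper fetched only two and returned them as a flat list. The fix below asks the runtime for its true output count and rebuilds the nesting by walking the function's checked return type. A minimal standalone sketch of the same recursion, with plain Python strings and tuples standing in for TVM's TensorType/TupleType (this spec encoding is illustrative, not TVM's API):

    # Rebuild a nested structure from a flat iterator of outputs,
    # mirroring the _unflatten helper added in this commit.
    def unflatten(flat_iter, type_spec):
        if type_spec == "tensor":
            # A leaf consumes exactly one flat output.
            return next(flat_iter)
        if isinstance(type_spec, tuple):
            # A tuple recurses over its field types in order.
            return [unflatten(flat_iter, field) for field in type_spec]
        raise ValueError("unsupported type spec", type_spec)

    # Return type (x, (y, (z, w))) flattens to four outputs.
    spec = ("tensor", ("tensor", ("tensor", "tensor")))
    assert unflatten(iter([1, 2, 3, 4]), spec) == [1, [2, [3, 4]]]
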

diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 79eb7e4..4c9a898 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -391,10 +391,20 @@ class GraphExecutor(_interpreter.Executor):
         ret_type = self.mod["main"].checked_type.ret_type
         if _ty.is_dynamic(ret_type):
             raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type)
-        num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
         mod = build(self.mod, target=self.target)
         gmodule = _graph_rt.GraphModule(mod["default"](self.ctx))
 
+        def _unflatten(flat_iter, cur_type):
+            if isinstance(cur_type, _ty.TensorType):
+                return next(flat_iter)
+            if isinstance(cur_type, _ty.TupleType):
+                fields = []
+                for field_type in cur_type.fields:
+                    field = _unflatten(flat_iter, field_type)
+                    fields.append(field)
+                return fields
+            raise ValueError("Return type", ret_type, "contains unsupported type", cur_type)
+
         def _graph_wrapper(*args, **kwargs):
             args = self._convert_args(self.mod["main"], args, kwargs)
             # Create map of inputs.
@@ -402,13 +412,11 @@ class GraphExecutor(_interpreter.Executor):
                 gmodule.set_input(i, arg)
             # Run the module, and fetch the output.
             gmodule.run()
-            # make a copy so multiple invocation won't hurt perf.
-            if num_outputs == 1:
-                return gmodule.get_output(0).copyto(_nd.cpu(0))
-            outputs = []
-            for i in range(num_outputs):
-                outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
-            return outputs
+            flattened = []
+            for i in range(gmodule.get_num_outputs()):
+                flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
+            unflattened = _unflatten(iter(flattened), ret_type)
+            return unflattened
 
         return _graph_wrapper
 
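One detail worth noting in the new loop: each output is still copied off the device with copyto(_nd.cpu(0)). get_output returns a reference into runtime-owned storage that is reused across runs, so returning the NDArrays directly would let a later run() clobber results from an earlier call. A small illustration of that aliasing hazard, with numpy standing in for TVM NDArrays:

    import numpy as np

    buf = np.zeros(3)           # stands in for a runtime output buffer
    def run(values):
        buf[:] = values         # every run writes into the same storage

    run([1, 2, 3])
    first = buf                 # no copy: 'first' aliases the buffer
    run([4, 5, 6])
    assert first[0] == 4        # the first call's result was overwritten
    safe = buf.copy()           # copying, as copyto(cpu(0)) does, avoids this
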
diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py
index 3c42b7b..68708aa 100644
--- a/tests/python/relay/test_backend_graph_runtime.py
+++ b/tests/python/relay/test_backend_graph_runtime.py
@@ -209,6 +209,27 @@ def test_compile_nested_tuples():
         ref = ref + 1
 
 
+def test_graph_executor_nested_tuples():
+    x, y, z, w = [relay.var(c, shape=(2, 3), dtype="float32") for c in "xyzw"]
+    out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])])
+    func = relay.Function([x, y, z, w], out)
+
+    exe = relay.create_executor(
+        kind="graph", mod=tvm.IRModule.from_expr(func), ctx=tvm.cpu(0), target="llvm"
+    )
+    f = exe.evaluate()
+
+    data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"]
+    out = f(*data)
+    assert len(out) == 2
+    tvm.testing.assert_allclose(out[0].asnumpy(), data[0])
+    assert len(out[1]) == 2
+    tvm.testing.assert_allclose(out[1][0].asnumpy(), data[1])
+    assert len(out[1][1]) == 2
+    tvm.testing.assert_allclose(out[1][1][0].asnumpy(), data[2])
+    tvm.testing.assert_allclose(out[1][1][1].asnumpy(), data[3])
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
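
Taken together with the new test above: after this change, calling f(*data) on a function whose return type is (x, (y, (z, w))) yields the nested list [x, [y, [z, w]]] with NDArray leaves, where the old wrapper counted only top-level tuple fields and so returned a flat (and, for nested tuples, incomplete) list of outputs.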