You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/06 09:28:10 UTC
[tvm] branch main updated: [Executor][Bugfix] Properly return and unflatten outputs from GraphExecutor (#7604)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 875f8ee [Executor][Bugfix] Properly return and unflatten outputs from GraphExecutor (#7604)
875f8ee is described below
commit 875f8ee2a704d1de6d6a681ac2c7b7073b73e79c
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Sat Mar 6 01:27:49 2021 -0800
[Executor][Bugfix] Properly return and unflatten outputs from GraphExecutor (#7604)
* properly return and unflatten outputs from GraphExecutor
* lint
* cleaner approach, not sure what I was thinking before
* remove unused import
* forgot copyto cpu
* make solution even cleaner using iterator
---
python/tvm/relay/build_module.py | 24 ++++++++++++++++--------
tests/python/relay/test_backend_graph_runtime.py | 21 +++++++++++++++++++++
2 files changed, 37 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 79eb7e4..4c9a898 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -391,10 +391,20 @@ class GraphExecutor(_interpreter.Executor):
ret_type = self.mod["main"].checked_type.ret_type
if _ty.is_dynamic(ret_type):
raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type)
- num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
mod = build(self.mod, target=self.target)
gmodule = _graph_rt.GraphModule(mod["default"](self.ctx))
+ def _unflatten(flat_iter, cur_type):
+ if isinstance(cur_type, _ty.TensorType):
+ return next(flat_iter)
+ if isinstance(cur_type, _ty.TupleType):
+ fields = []
+ for field_type in cur_type.fields:
+ field = _unflatten(flat_iter, field_type)
+ fields.append(field)
+ return fields
+ raise ValueError("Return type", ret_type, "contains unsupported type", cur_type)
+
def _graph_wrapper(*args, **kwargs):
args = self._convert_args(self.mod["main"], args, kwargs)
# Create map of inputs.
@@ -402,13 +412,11 @@ class GraphExecutor(_interpreter.Executor):
gmodule.set_input(i, arg)
# Run the module, and fetch the output.
gmodule.run()
- # make a copy so multiple invocation won't hurt perf.
- if num_outputs == 1:
- return gmodule.get_output(0).copyto(_nd.cpu(0))
- outputs = []
- for i in range(num_outputs):
- outputs.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
- return outputs
+ flattened = []
+ for i in range(gmodule.get_num_outputs()):
+ flattened.append(gmodule.get_output(i).copyto(_nd.cpu(0)))
+ unflattened = _unflatten(iter(flattened), ret_type)
+ return unflattened
return _graph_wrapper
diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py
index 3c42b7b..68708aa 100644
--- a/tests/python/relay/test_backend_graph_runtime.py
+++ b/tests/python/relay/test_backend_graph_runtime.py
@@ -209,6 +209,27 @@ def test_compile_nested_tuples():
ref = ref + 1
+def test_graph_executor_nested_tuples():
+ x, y, z, w = [relay.var(c, shape=(2, 3), dtype="float32") for c in "xyzw"]
+ out = relay.Tuple([x, relay.Tuple([y, relay.Tuple([z, w])])])
+ func = relay.Function([x, y, z, w], out)
+
+ exe = relay.create_executor(
+ kind="graph", mod=tvm.IRModule.from_expr(func), ctx=tvm.cpu(0), target="llvm"
+ )
+ f = exe.evaluate()
+
+ data = [np.random.uniform(size=(2, 3)).astype("float32") for _ in "xyzw"]
+ out = f(*data)
+ assert len(out) == 2
+ tvm.testing.assert_allclose(out[0].asnumpy(), data[0])
+ assert len(out[1]) == 2
+ tvm.testing.assert_allclose(out[1][0].asnumpy(), data[1])
+ assert len(out[1][1]) == 2
+ tvm.testing.assert_allclose(out[1][1][0].asnumpy(), data[2])
+ tvm.testing.assert_allclose(out[1][1][1].asnumpy(), data[3])
+
+
if __name__ == "__main__":
test_plan_memory()
test_with_params()