Posted to commits@tvm.apache.org by ma...@apache.org on 2020/12/15 11:56:29 UTC

[tvm] branch main updated: [ONNX] Fix a bug with reshape imports when an initialized target shape is used more than once (#7109)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new faed409  [ONNX] Fix a bug with reshape imports when an initialized target shape is used more than once (#7109)
faed409 is described below

commit faed4096536bd86be5731a113d7b34ecd55ddc72
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Tue Dec 15 04:56:09 2020 -0700

    [ONNX] Fix a bug with reshape imports when an initialized target shape is used more than once (#7109)
    
    * Fix a bug with reshape imports when an initialized target shape is used more than once
    
    * run autoformat
---
 python/tvm/relay/frontend/onnx.py          |  3 +--
 tests/python/frontend/onnx/test_forward.py | 37 +++++++++++++++++++++++++++++-
 2 files changed, 37 insertions(+), 3 deletions(-)

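For context on the fix below: the v5 Reshape importer popped the target-shape
initializer out of `params` on first use, so a second Reshape node referencing
the same initializer no longer found it and fell through to the dynamic-shape
path with a dangling input. A minimal sketch of the before/after behavior
(plain Python with a hypothetical `params` dict standing in for the frontend's
initializer map, not the actual TVM code):

    def lookup_shape_old(params, name):
        # Old behavior: pop() removes the entry, so only the FIRST
        # Reshape node sharing this initializer can resolve it.
        if name in params:
            return params.pop(name)
        # Sketch only: the real frontend does not raise here, it falls
        # through to _op.reshape(*inputs) with a now-unbound shape input,
        # but the effect is the same -- the shared shape is gone.
        raise KeyError(name)

    def lookup_shape_new(params, name):
        # Fixed behavior: plain indexing leaves the entry in place,
        # so any number of Reshape nodes can share one target shape.
        if name in params:
            return params[name]
        raise KeyError(name)

    params_old = {"ref_in": (6, 2, 4, 3)}
    lookup_shape_old(params_old, "ref_in")    # first Reshape: ok
    # lookup_shape_old(params_old, "ref_in")  # second Reshape: KeyError

    params_new = {"ref_in": (6, 2, 4, 3)}
    lookup_shape_new(params_new, "ref_in")    # ok
    lookup_shape_new(params_new, "ref_in")    # still ok
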
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 23102aa..cbec322 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -859,8 +859,7 @@ class Reshape(OnnxOpConverter):
     @classmethod
     def _impl_v5(cls, inputs, attr, params):
         if get_name(inputs[1]) in params:
-            # pop shape out of parameters since it wont be needed later.
-            shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32"))
+            shape = tuple(params[inputs[1].name_hint].asnumpy().astype("int32"))
             out = _op.reshape(inputs[0], shape)
         else:
             out = _op.reshape(*inputs)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index bae50c9..33dd048 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -73,7 +73,6 @@ def get_tvm_output(
     input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
-
     with tvm.transform.PassContext(opt_level=1):
         graph, lib, params = relay.build(mod, target, params=params)
 
@@ -234,6 +233,42 @@ def test_reshape():
         tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
+@tvm.testing.uses_gpu
+def test_double_reshape():
+    in_shape = (4, 3, 3, 4)
+    ref_shape = (6, 2, 4, 3)
+
+    ref_array = np.array(ref_shape)
+    ref_node = onnx.helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["ref_in"],
+        value=onnx.helper.make_tensor(
+            name="const_tensor",
+            data_type=onnx.TensorProto.INT32,
+            dims=ref_array.shape,
+            vals=ref_array.flatten().astype(int),
+        ),
+    )
+    reshape_node1 = helper.make_node("Reshape", ["in", "ref_in"], ["out1"])
+    reshape_node2 = helper.make_node("Reshape", ["in", "ref_in"], ["out2"])
+    add_node = helper.make_node("Add", ["out1", "out2"], ["out"])
+
+    graph = helper.make_graph(
+        [ref_node, reshape_node1, reshape_node2, add_node],
+        "reshape_test",
+        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="reshape_test")
+
+    for target, ctx in tvm.testing.enabled_targets():
+        x = np.random.uniform(size=in_shape).astype("int32")
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
+        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+
+
 # TODO(mbrookhart): enable once VM supports heterogenous execution
 # @tvm.testing.uses_gpu
 def test_expand():
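
To exercise the new test locally, something along the lines of
`pytest tests/python/frontend/onnx/test_forward.py -k test_double_reshape`
should work, assuming a TVM build with the Relay ONNX frontend and the
`onnx` package installed.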