You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ha...@apache.org on 2020/06/17 17:05:46 UTC

[incubator-tvm] branch master updated: [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825)

This is an automated email from the ASF dual-hosted git repository.

haichen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e28bcd  [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825)
5e28bcd is described below

commit 5e28bcdd9a0ed7e230dc6dee5a7e50a580e1f148
Author: Yao Wang <ke...@gmail.com>
AuthorDate: Wed Jun 17 10:05:35 2020 -0700

    [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825)
    
    * Fix TF dynamic input shape: unknown (negative) input dimensions now map to Any() instead of 1
    
    * Remove the "Use 1 instead of -1 in shape of operator" warning
    
    * Add a test for a dynamic input shape, run through the Relay VM
---
 python/tvm/relay/frontend/tensorflow.py          |  4 +--
 tests/python/frontend/tensorflow/test_forward.py | 38 ++++++++++++++++++++++--
 2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index af09877..62dadce 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -2824,9 +2824,7 @@ class GraphProto(object):
                         tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
                     for idx, dim in enumerate(self._input_shapes[node.name]):
                         if dim < 0:
-                            self._input_shapes[node.name][idx] = 1
-                            warnings.warn("Use 1 instead of -1 in shape of operator %s."
-                                          % node.name)
+                            self._input_shapes[node.name][idx] = Any()
 
                 self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 attr = self._parse_attr(node.attr)
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 6f3b7f4..1a0baf8 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -100,14 +100,18 @@ def vmobj_to_list(o):
 
 def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
                   target='llvm', out_names=None, opt_level=3, mode='graph_runtime',
-                  cuda_layout="NCHW", layout=None, disabled_pass=None):
+                  cuda_layout="NCHW", layout=None, disabled_pass=None, ignore_in_shape=False):
     """ Generic function to compile on relay and execute on tvm """
     input_data = convert_to_list(input_data)
     input_node = convert_to_list(input_node)
     if target == "cuda":
         layout = cuda_layout
     target_host = None
-    shape_dict = {e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)}
+    if ignore_in_shape:
+        shape_dict = None
+    else:
+        shape_dict = {e: i.shape if hasattr(i, "shape") else ()
+                      for e, i in zip(input_node, input_data)}
     mod, params = relay.frontend.from_tensorflow(graph_def,
                                                  layout=layout,
                                                  shape=shape_dict,
@@ -3715,6 +3719,33 @@ def test_forward_spop():
     _test_spop_variables()
     _test_spop_constants()
 
+#######################################################################
+# Dynamic input shape
+# -------------------
+def test_forward_dynamic_input_shape():
+    tf.reset_default_graph()
+
+    with tf.Graph().as_default():
+        data = tf.placeholder(tf.float32, name='data', shape=(None,))
+        out = data + 1
+        np_data = np.random.uniform(size=(2,)).astype("float32")
+        out_name = "add"
+
+        with tf.Session() as sess:
+            graph_def = tf_testing.AddShapesToGraphDef(sess, out_name)
+            tf_output = run_tf_graph(sess, np_data, 'data:0', ['{}:0'.format(out_name)])
+            # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready.
+            for device in ["llvm"]:
+                ctx = tvm.context(device, 0)
+                if not ctx.exist:
+                    print("Skip because %s is not enabled" % device)
+                    continue
+                tvm_output = run_tvm_graph(graph_def, np_data, ["data"], 1,
+                                           target=device, layout="NCHW", out_names=[out_name],
+                                           mode="vm", ignore_in_shape=True)
+                tvm.testing.assert_allclose(tvm_output[0], tf_output[0],
+                                            rtol=1e-5, atol=1e-5)
+
 
 #######################################################################
 # Main
@@ -3851,3 +3882,6 @@ if __name__ == '__main__':
 
     # StatefulPartitionedCall
     test_forward_spop()
+
+    # Test dynamic input shape
+    test_forward_dynamic_input_shape()