You are viewing a plain-text version of this content. The canonical link to the original message was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ha...@apache.org on 2020/06/17 17:05:46 UTC
[incubator-tvm] branch master updated: [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825)
This is an automated email from the ASF dual-hosted git repository.
haichen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 5e28bcd [Frontend][TensorFlow]Fix TF Dynamic input shape (#5825)
5e28bcd is described below
commit 5e28bcdd9a0ed7e230dc6dee5a7e50a580e1f148
Author: Yao Wang <ke...@gmail.com>
AuthorDate: Wed Jun 17 10:05:35 2020 -0700
[Frontend][TensorFlow]Fix TF Dynamic input shape (#5825)
* Fix TF Dynamic input shape
* Remove warning
* Add test
---
python/tvm/relay/frontend/tensorflow.py | 4 +--
tests/python/frontend/tensorflow/test_forward.py | 38 ++++++++++++++++++++++--
2 files changed, 37 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index af09877..62dadce 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -2824,9 +2824,7 @@ class GraphProto(object):
tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
for idx, dim in enumerate(self._input_shapes[node.name]):
if dim < 0:
- self._input_shapes[node.name][idx] = 1
- warnings.warn("Use 1 instead of -1 in shape of operator %s."
- % node.name)
+ self._input_shapes[node.name][idx] = Any()
self._output_shapes[node.name] = [self._input_shapes[node.name]]
attr = self._parse_attr(node.attr)
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 6f3b7f4..1a0baf8 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -100,14 +100,18 @@ def vmobj_to_list(o):
def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3, mode='graph_runtime',
- cuda_layout="NCHW", layout=None, disabled_pass=None):
+ cuda_layout="NCHW", layout=None, disabled_pass=None, ignore_in_shape=False):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
if target == "cuda":
layout = cuda_layout
target_host = None
- shape_dict = {e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)}
+ if ignore_in_shape:
+ shape_dict = None
+ else:
+ shape_dict = {e: i.shape if hasattr(i, "shape") else ()
+ for e, i in zip(input_node, input_data)}
mod, params = relay.frontend.from_tensorflow(graph_def,
layout=layout,
shape=shape_dict,
@@ -3715,6 +3719,33 @@ def test_forward_spop():
_test_spop_variables()
_test_spop_constants()
+#######################################################################
+# Dynamic input shape
+# -------------------
+def test_forward_dynamic_input_shape():
+ tf.reset_default_graph()
+
+ with tf.Graph().as_default():
+ data = tf.placeholder(tf.float32, name='data', shape=(None,))
+ out = data + 1
+ np_data = np.random.uniform(size=(2,)).astype("float32")
+ out_name = "add"
+
+ with tf.Session() as sess:
+ graph_def = tf_testing.AddShapesToGraphDef(sess, out_name)
+ tf_output = run_tf_graph(sess, np_data, 'data:0', ['{}:0'.format(out_name)])
+ # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready.
+ for device in ["llvm"]:
+ ctx = tvm.context(device, 0)
+ if not ctx.exist:
+ print("Skip because %s is not enabled" % device)
+ continue
+ tvm_output = run_tvm_graph(graph_def, np_data, ["data"], 1,
+ target=device, layout="NCHW", out_names=[out_name],
+ mode="vm", ignore_in_shape=True)
+ tvm.testing.assert_allclose(tvm_output[0], tf_output[0],
+ rtol=1e-5, atol=1e-5)
+
#######################################################################
# Main
@@ -3851,3 +3882,6 @@ if __name__ == '__main__':
# StatefulPartitionedCall
test_forward_spop()
+
+ # Test dynamic input shape
+ test_forward_dynamic_input_shape()