You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/10 14:48:03 UTC
[incubator-tvm] branch master updated: [FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 0d1babc [FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)
0d1babc is described below
commit 0d1babce46973f3c4662d11d429fe4e38c68eede
Author: MORITA Kazutaka <mo...@gmail.com>
AuthorDate: Fri Apr 10 23:47:53 2020 +0900
[FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)
* [FRONTEND][TENSORFLOW] Fix gather_nd indices
* retrigger CI
---
python/tvm/relay/frontend/tensorflow.py | 4 +++-
tests/python/frontend/tensorflow/test_forward.py | 6 +++---
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 77dbcb5..8a72423 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1127,9 +1127,11 @@ def _gather():
def _gather_nd():
"""GatherNd"""
def _impl(inputs, attr, params, mod):
+ indices_dims = len(_infer_shape(inputs[1], mod))
+ indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims-1)))
return AttrCvt(op_name="gather_nd",
ignores=['Tindices', 'Tparams',\
- 'Taxis', '_class'])(inputs, attr)
+ 'Taxis', '_class'])([inputs[0], indices], attr)
return _impl
def _stridedSlice():
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 35a3466..fdb8912 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1365,11 +1365,11 @@ def test_forward_gather():
def test_forward_gather_nd():
"""test operator GatherNd"""
- np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
+ np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32)
tf.reset_default_graph()
with tf.Graph().as_default():
- in_data = tf.placeholder(tf.float32, (2, 2), name="in_data")
- tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd")
+ in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data")
+ tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd")
compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')