You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/10 14:48:03 UTC

[incubator-tvm] branch master updated: [FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d1babc  [FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)
0d1babc is described below

commit 0d1babce46973f3c4662d11d429fe4e38c68eede
Author: MORITA Kazutaka <mo...@gmail.com>
AuthorDate: Fri Apr 10 23:47:53 2020 +0900

    [FRONTEND][TENSORFLOW] Fix gather_nd indices (#5279)
    
    * [FRONTEND][TENSORFLOW] Fix gather_nd indices
    
    * retrigger CI
---
 python/tvm/relay/frontend/tensorflow.py          | 4 +++-
 tests/python/frontend/tensorflow/test_forward.py | 6 +++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 77dbcb5..8a72423 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1127,9 +1127,11 @@ def _gather():
 def _gather_nd():
     """GatherNd"""
     def _impl(inputs, attr, params, mod):
+        indices_dims = len(_infer_shape(inputs[1], mod))
+        indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims-1)))
         return AttrCvt(op_name="gather_nd",
                        ignores=['Tindices', 'Tparams',\
-                                'Taxis', '_class'])(inputs, attr)
+                                'Taxis', '_class'])([inputs[0], indices], attr)
     return _impl
 
 def _stridedSlice():
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 35a3466..fdb8912 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1365,11 +1365,11 @@ def test_forward_gather():
 
 def test_forward_gather_nd():
     """test operator GatherNd"""
-    np_data = np.random.uniform(1, 100, size=(2, 2)).astype(np.float32)
+    np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32)
     tf.reset_default_graph()
     with tf.Graph().as_default():
-        in_data = tf.placeholder(tf.float32, (2, 2), name="in_data")
-        tf.gather_nd(in_data, indices=[[1, 0], [0, 1]], name="gather_nd")
+        in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data")
+        tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd")
         compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0')