You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/11/03 16:25:38 UTC

[incubator-tvm] branch main updated: [TF] Fix a bug in _stridedSlice() (#6829)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 26b2e16  [TF] Fix a bug in _stridedSlice() (#6829)
26b2e16 is described below

commit 26b2e1649db10963bf0725e3b4f9e0cb53d4b9d5
Author: lixiaoquan <ra...@163.com>
AuthorDate: Wed Nov 4 00:25:24 2020 +0800

    [TF] Fix a bug in _stridedSlice() (#6829)
    
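    When stride < 0, the slicing range for a whole dimension should be
      [-1, -(dim+1)), so that the reversed slice starts at the last element
      and still includes index 0 (the previous end of 0 is exclusive and
      dropped it).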
---
 python/tvm/relay/frontend/tensorflow.py          |  8 ++++++--
 tests/python/frontend/tensorflow/test_forward.py | 10 ++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 2c7adf0..a6fd1db 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1616,11 +1616,15 @@ def _stridedSlice():
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[final_index] if stride[index] < 0 else 0
+                        m_begin[final_index] = -1 if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
-                        m_end[final_index] = 0 if stride[index] < 0 else data_shape[final_index]
+                        m_end[final_index] = (
+                            -(data_shape[final_index] + 1)
+                            if stride[index] < 0
+                            else data_shape[final_index]
+                        )
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 5ec4562..12ec073 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1880,6 +1880,16 @@ def test_forward_stridedslice():
         begin_mask=5,
         end_mask=8,
     )
+    _test_stridedslice(
+        (1, 13, 13, 3, 2),
+        [0, 0],
+        [1, 1],
+        [1, -1],
+        "float32",
+        ellipsis_mask=1,
+        begin_mask=2,
+        end_mask=2,
+    )
 
 
 #######################################################################