You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/11/03 16:25:38 UTC
[incubator-tvm] branch main updated: [TF] Fix a bug in _stridedSlice() (#6829)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 26b2e16 [TF] Fix a bug in _stridedSlice() (#6829)
26b2e16 is described below
commit 26b2e1649db10963bf0725e3b4f9e0cb53d4b9d5
Author: lixiaoquan <ra...@163.com>
AuthorDate: Wed Nov 4 00:25:24 2020 +0800
[TF] Fix a bug in _stridedSlice() (#6829)
When stride < 0, the slicing range for the whole dimension should be
[-1, -(dim+1))
---
python/tvm/relay/frontend/tensorflow.py | 8 ++++++--
tests/python/frontend/tensorflow/test_forward.py | 10 ++++++++++
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 2c7adf0..a6fd1db 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1616,11 +1616,15 @@ def _stridedSlice():
if final_index == len(m_begin):
break
if mask & begin_mask:
- m_begin[final_index] = data_shape[final_index] if stride[index] < 0 else 0
+ m_begin[final_index] = -1 if stride[index] < 0 else 0
elif begin[index]:
m_begin[final_index] = begin[index]
if mask & end_mask:
- m_end[final_index] = 0 if stride[index] < 0 else data_shape[final_index]
+ m_end[final_index] = (
+ -(data_shape[final_index] + 1)
+ if stride[index] < 0
+ else data_shape[final_index]
+ )
elif end[index]:
m_end[final_index] = end[index]
m_stride[final_index] = stride[index]
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 5ec4562..12ec073 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1880,6 +1880,16 @@ def test_forward_stridedslice():
begin_mask=5,
end_mask=8,
)
+ _test_stridedslice(
+ (1, 13, 13, 3, 2),
+ [0, 0],
+ [1, 1],
+ [1, -1],
+ "float32",
+ ellipsis_mask=1,
+ begin_mask=2,
+ end_mask=2,
+ )
#######################################################################