This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 26b2e16  [TF] Fix a bug in _stridedSlice() (#6829)
26b2e16 is described below

commit 26b2e1649db10963bf0725e3b4f9e0cb53d4b9d5
Author: lixiaoquan <[email protected]>
AuthorDate: Wed Nov 4 00:25:24 2020 +0800

    [TF] Fix a bug in _stridedSlice() (#6829)
    
    When stride < 0, the slicing range for the whole dimension should be
      [-1, -(dim+1)); the previous range [dim, 0) excluded index 0.
---
 python/tvm/relay/frontend/tensorflow.py          |  8 ++++++--
 tests/python/frontend/tensorflow/test_forward.py | 10 ++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 2c7adf0..a6fd1db 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1616,11 +1616,15 @@ def _stridedSlice():
                     if final_index == len(m_begin):
                         break
                     if mask & begin_mask:
-                        m_begin[final_index] = data_shape[final_index] if 
stride[index] < 0 else 0
+                        m_begin[final_index] = -1 if stride[index] < 0 else 0
                     elif begin[index]:
                         m_begin[final_index] = begin[index]
                     if mask & end_mask:
-                        m_end[final_index] = 0 if stride[index] < 0 else 
data_shape[final_index]
+                        m_end[final_index] = (
+                            -(data_shape[final_index] + 1)
+                            if stride[index] < 0
+                            else data_shape[final_index]
+                        )
                     elif end[index]:
                         m_end[final_index] = end[index]
                     m_stride[final_index] = stride[index]
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 5ec4562..12ec073 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1880,6 +1880,16 @@ def test_forward_stridedslice():
         begin_mask=5,
         end_mask=8,
     )
+    _test_stridedslice(
+        (1, 13, 13, 3, 2),
+        [0, 0],
+        [1, 1],
+        [1, -1],
+        "float32",
+        ellipsis_mask=1,
+        begin_mask=2,
+        end_mask=2,
+    )
 
 
 #######################################################################

Reply via email to