This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 65606c9  [ONNX] Normalize axes for Slice (#9517)
65606c9 is described below

commit 65606c9ec0f74479f4c3fed22788fa1e326b18b1
Author: Matthew Brookhart <[email protected]>
AuthorDate: Wed Nov 17 20:40:17 2021 -0700

    [ONNX] Normalize axes for Slice (#9517)
---
 python/tvm/relay/frontend/onnx.py          | 10 +++++++++-
 tests/python/frontend/onnx/test_forward.py |  1 +
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 0dc08d5..e8c8e28 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1522,6 +1522,15 @@ class Slice(OnnxOpConverter):
         ishape = infer_shape(inputs[0])
         data_rank = len(ishape)
 
+        if axes is not None:
+            # Normalize for negative axes
+            axes_dtype = infer_type(axes).checked_type.dtype
+            axes = fold_constant(
+                _op.where(
+                    axes < _op.const(0, axes_dtype), axes + 
_op.const(data_rank, axes_dtype), axes
+                )
+            )
+
         def has_static_axes():
             return (
                 isinstance(axes, _expr.Constant)
@@ -1538,7 +1547,6 @@ class Slice(OnnxOpConverter):
                 strides_np = np.ones_like(begin_np).astype("int64")
             else:
                 strides_np = steps.data.numpy().astype("int64")
-
             if all([isinstance(ishape[i], int) for i in axes_np]):
                 return _op.strided_slice(
                     inputs[0], list(begin_np), list(end_np), list(strides_np), 
axes=list(axes_np)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0531bfc..27d1242 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -894,6 +894,7 @@ def test_slice(target, dev):
     _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 
4))
     _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), 
axes=(1,))
     _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), 
axes=(1,))
+    _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), 
axes=(-1,))
     _test_slice_iteration_v10(
         x,
         x[0:3, 0:10],

Reply via email to