This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 65606c9 [ONNX] Normalize axes for Slice (#9517)
65606c9 is described below
commit 65606c9ec0f74479f4c3fed22788fa1e326b18b1
Author: Matthew Brookhart <[email protected]>
AuthorDate: Wed Nov 17 20:40:17 2021 -0700
[ONNX] Normalize axes for Slice (#9517)
---
python/tvm/relay/frontend/onnx.py | 10 +++++++++-
tests/python/frontend/onnx/test_forward.py | 1 +
2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 0dc08d5..e8c8e28 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1522,6 +1522,15 @@ class Slice(OnnxOpConverter):
ishape = infer_shape(inputs[0])
data_rank = len(ishape)
+ if axes is not None:
+ # Normalize for negative axes
+ axes_dtype = infer_type(axes).checked_type.dtype
+ axes = fold_constant(
+ _op.where(
+ axes < _op.const(0, axes_dtype), axes +
_op.const(data_rank, axes_dtype), axes
+ )
+ )
+
def has_static_axes():
return (
isinstance(axes, _expr.Constant)
@@ -1538,7 +1547,6 @@ class Slice(OnnxOpConverter):
strides_np = np.ones_like(begin_np).astype("int64")
else:
strides_np = steps.data.numpy().astype("int64")
-
if all([isinstance(ishape[i], int) for i in axes_np]):
return _op.strided_slice(
inputs[0], list(begin_np), list(end_np), list(strides_np),
axes=list(axes_np)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0531bfc..27d1242 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -894,6 +894,7 @@ def test_slice(target, dev):
_test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10,
4))
_test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,),
axes=(1,))
_test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,),
axes=(1,))
+ _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,),
axes=(-1,))
_test_slice_iteration_v10(
x,
x[0:3, 0:10],