This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e8caffa489 [Unity][Op] Fix Strided Slice Shape Inference (#14324)
e8caffa489 is described below
commit e8caffa489947002649a4c53b4b2f6ba8dbe53b6
Author: Xiyou Zhou <[email protected]>
AuthorDate: Sun Mar 19 18:46:41 2023 -0700
[Unity][Op] Fix Strided Slice Shape Inference (#14324)
When the begin index of a strided slice falls outside [-dim, dim), where dim is
the extent of the sliced axis, the strided slice operator infers an incorrect
output shape. This PR corrects the issue by canonicalizing the begin and end
indices of the strided slice before calculating the (possibly symbolic) output
shape. It also adds unit tests for out-of-range begin locations and updates the
symbolic shape tests accordingly.
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
---
src/relax/op/tensor/index.cc | 29 ++++++++++++++++++++++----
tests/python/relax/test_op_index.py | 41 +++++++++++++++++++++++++++++++++----
2 files changed, 62 insertions(+), 8 deletions(-)
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 218de6e2c6..ac3ce084c4 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -142,6 +142,26 @@ Expr strided_slice(Expr x, //
TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t
stride) {
+ // Same as topi strided slice CanonicalizeIndex function in
+ // include/tvm/topi/detail/strided_slice.h
+ PrimExpr begin_range = stride < 0 ? -1 : 0;
+ PrimExpr end_range = stride < 0 ? extent - 1 : extent;
+ index = if_then_else(index < 0, index + extent, index);
+ return min(max(index, begin_range), end_range); // NOLINT
+}
+
+PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const
PrimExpr& length) {
+ begin = CanonicalizeIndex(begin, length, stride);
+ end = CanonicalizeIndex(end, length, stride);
+
+ if (stride < 0) {
+ return ceildiv(begin - end, IntImm(DataType::Int(64), -stride));
+ } else {
+ return ceildiv(end - begin, IntImm(DataType::Int(64), stride));
+ }
+}
+
StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder&
ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<StridedSliceAttrs>();
@@ -163,7 +183,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
Array<PrimExpr> strides = attrs->strides.defined()
? attrs->strides.value()
: Array<PrimExpr>(n_axis,
IntImm(DataType::Int(64), 1));
- std::vector<int> int_strides;
+ std::vector<int64_t> int_strides;
int_strides.reserve(n_axis);
// Only do output shape inference when all the begin/end/stride values are
integers.
for (int i = 0; i < n_axis; ++i) {
@@ -178,9 +198,10 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
Array<PrimExpr> output_shape = data_shape->values;
for (int i = 0; i < n_axis; ++i) {
- PrimExpr len = int_strides[i] < 0 ? ceildiv(attrs->begin[i] -
attrs->end[i], -int_strides[i])
- : ceildiv(attrs->end[i] -
attrs->begin[i], int_strides[i]);
- output_shape.Set(axes[i], len);
+ ICHECK_NE(int_strides[i], 0)
+ << "Strided slice requires stride to be non-zero but got 0 for axis "
<< axes[i] << ".";
+ output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i],
int_strides[i],
+ data_shape->values[axes[i]]));
}
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
diff --git a/tests/python/relax/test_op_index.py
b/tests/python/relax/test_op_index.py
index 77a04b1a1a..a84e70a0eb 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -366,6 +366,39 @@ def test_strided_slice_infer_struct_info():
)
+def test_strided_slice_infer_struct_info_shape_out_of_range():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((20, 10, 5), "float32"))
+ _check_inference(
+ bb,
+ relax.op.strided_slice(
+ x0, axes=[0, 1, 2], begin=[20, 10, 4], end=[0, 0, 1], strides=[-1,
-3, -2]
+ ),
+ relax.TensorStructInfo((19, 3, 2), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(
+ x0, axes=[0, 1, 2], begin=[200, 10, 4], end=[0, 0, 1],
strides=[-1, -3, -2]
+ ),
+ relax.TensorStructInfo((19, 3, 2), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(
+ x0, axes=[0, 1, 2], begin=[200, 10, 100], end=[0, 0, 1],
strides=[-1, -3, -5]
+ ),
+ relax.TensorStructInfo((19, 3, 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(
+ x0, axes=[0, 1, 2], begin=[-21, -11, -6], end=[1, 1, 1],
strides=[1000, 1000, 1000]
+ ),
+ relax.TensorStructInfo((1, 1, 1), "float32"),
+ )
+
+
def test_strided_slice_infer_struct_info_shape_symbolic():
bb = relax.BlockBuilder()
m = tir.Var("m", "int64")
@@ -376,22 +409,22 @@ def test_strided_slice_infer_struct_info_shape_symbolic():
_check_inference(
bb,
relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]),
- relax.TensorStructInfo((2, n), "float32"),
+ relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n),
"float32"),
)
_check_inference(
bb,
relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]),
- relax.TensorStructInfo((3, n), "float32"),
+ relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3,
n), "float32"),
)
_check_inference(
bb,
relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]),
- relax.TensorStructInfo((2, n), dtype=""),
+ relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n),
dtype=""),
)
_check_inference(
bb,
relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]),
- relax.TensorStructInfo((3, n), dtype=""),
+ relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3,
n), dtype=""),
)