This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 605a61450b [Unity][Op] Support symbolic shape inference for slice op. (#15450)
605a61450b is described below
commit 605a61450b770e092da7cd4d50ee52cf6454c107
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Aug 2 09:06:15 2023 +0800
[Unity][Op] Support symbolic shape inference for slice op. (#15450)
This PR makes two improvements:
1. Support symbolic `begin` and `end` values in the shape inference of the slice op.
2. Add a new attribute `assume_inbound` to the slice op. If `assume_inbound`
is set to True, the slice op assumes that `begin` and `end` are always
in bound, which simplifies the deduced output shape.
---
include/tvm/relax/attrs/index.h | 6 +++
python/tvm/relax/op/index.py | 8 +++-
src/relax/op/tensor/index.cc | 36 +++++++++--------
src/relax/op/tensor/index.h | 12 +++---
tests/python/relax/test_op_index.py | 78 +++++++++++++++++++++++++++++++++----
5 files changed, 108 insertions(+), 32 deletions(-)
diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index c95395a803..1043fe30ce 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -44,6 +44,7 @@ struct StridedSliceAttrs : public
tvm::AttrsNode<StridedSliceAttrs> {
Array<PrimExpr> begin;
Array<PrimExpr> end;
Optional<Array<PrimExpr>> strides;
+ bool assume_inbound;
TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied.");
@@ -53,6 +54,11 @@ struct StridedSliceAttrs : public
tvm::AttrsNode<StridedSliceAttrs> {
"Specifies the stride values, it can be negative in that case, the
input tensor will be "
"reversed in that particular axis. If not specified, it by default is
an list of ones of "
"the same length as `axes`.");
+ TVM_ATTR_FIELD(assume_inbound)
+ .set_default(true)
+ .describe(
+ "Whether to assume the indices are in bound. If it is set to
false, "
+ "out of bound indices will be clipped to the bound.");
}
}; // struct StridedSliceAttrs
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index 835c9350b0..8504b4d683 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -58,6 +58,7 @@ def strided_slice(
begin: List[PrimExprLike],
end: List[PrimExprLike],
strides: Optional[List[PrimExprLike]] = None,
+ assume_inbound: bool = False,
) -> Expr:
"""Strided slice of a tensor.
@@ -80,6 +81,9 @@ def strided_slice(
the input tensor will be reversed in that particular axis.
If not specified, it by default is an list of ones of the same length
as `axes`.
+ assume_inbound : bool
+ Whether to assume the indices are in bound. If it is set to false,
+ out of bound indices will be clipped to the bound.
Returns
-------
ret : relax.Expr
@@ -90,7 +94,7 @@ def strided_slice(
strided_slice require the input `begin`, `end` and `strides` to have the
same length as `axes`.
"""
- return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore
+ return _ffi_api.strided_slice(x, axes, begin, end, strides,
assume_inbound) # type: ignore
def dynamic_strided_slice(
@@ -99,7 +103,7 @@ def dynamic_strided_slice(
end: Expr,
strides: Expr,
) -> Expr:
- """Dynamic strided slice of a tensor. `begin`, `end`, `strids` can be
computed at runtime.
+ """Dynamic strided slice of a tensor. `begin`, `end`, `strides` can be
computed at runtime.
Parameters
----------
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index a9c61bb56a..5f1d5149b3 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -101,11 +101,12 @@ TVM_REGISTER_OP("relax.take")
/* relax.strided_slice */
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
-Expr strided_slice(Expr x, //
- Array<Integer> axes, //
- Array<PrimExpr> begin, //
- Array<PrimExpr> end, //
- Optional<Array<PrimExpr>> strides) {
+Expr strided_slice(Expr x, //
+ Array<Integer> axes, //
+ Array<PrimExpr> begin, //
+ Array<PrimExpr> end, //
+ Optional<Array<PrimExpr>> strides, //
+ bool assume_inbound) {
int n_axis = axes.size();
CHECK_EQ(static_cast<int>(begin.size()), n_axis)
<< "StridedSlice requires the number of begin indices to equal the
number of axes.";
@@ -141,6 +142,7 @@ Expr strided_slice(Expr x, //
attrs->begin = begin.Map(f_convert_to_int64);
attrs->end = end.Map(f_convert_to_int64);
attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64)
: strides;
+ attrs->assume_inbound = assume_inbound;
static const Op& op = Op::Get("relax.strided_slice");
return Call(op, {std::move(x)}, Attrs(attrs), {});
@@ -148,23 +150,25 @@ Expr strided_slice(Expr x, //
TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t
stride) {
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t
stride,
+ bool assume_inbound) {
// Same as topi strided slice CanonicalizeIndex function in
// include/tvm/topi/detail/strided_slice.h
PrimExpr begin_range = stride < 0 ? -1 : 0;
PrimExpr end_range = stride < 0 ? extent - 1 : extent;
index = if_then_else(index < 0, index + extent, index);
- return min(max(index, begin_range), end_range); // NOLINT
+ return assume_inbound ? index : min(max(index, begin_range), end_range); //
NOLINT
}
-PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const
PrimExpr& length) {
- begin = CanonicalizeIndex(begin, length, stride);
- end = CanonicalizeIndex(end, length, stride);
-
+PrimExpr GetLength(PrimExpr begin, PrimExpr end, const int64_t stride, const
PrimExpr& length,
+ bool assume_inbound) {
+ begin = CanonicalizeIndex(begin, length, stride, assume_inbound);
+ end = CanonicalizeIndex(end, length, stride, assume_inbound);
+ arith::Analyzer ana;
if (stride < 0) {
- return ceildiv(begin - end, IntImm(DataType::Int(64), -stride));
+ return ana.Simplify(ceildiv(begin - end, IntImm(DataType::Int(64),
-stride)));
} else {
- return ceildiv(end - begin, IntImm(DataType::Int(64), stride));
+ return ana.Simplify(ceildiv(end - begin, IntImm(DataType::Int(64),
stride)));
}
}
@@ -193,10 +197,8 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
int_strides.reserve(n_axis);
// Only do output shape inference when all the begin/end/strides values are
integers.
for (int i = 0; i < n_axis; ++i) {
- const auto* int_begin = attrs->begin[i].as<IntImmNode>();
- const auto* int_end = attrs->end[i].as<IntImmNode>();
const auto* int_stride = strides[i].as<IntImmNode>();
- if (!int_begin || !int_end || !int_stride) {
+ if (!int_stride) {
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
}
int_strides.push_back(int_stride->value);
@@ -207,7 +209,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
ICHECK_NE(int_strides[i], 0)
<< "Strided slice requires strides to be non-zero but got 0 for axis "
<< axes[i] << ".";
output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i],
int_strides[i],
- data_shape->values[axes[i]]));
+ data_shape->values[axes[i]],
attrs->assume_inbound));
}
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}
diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h
index 6944493a0f..c8c7428f48 100644
--- a/src/relax/op/tensor/index.h
+++ b/src/relax/op/tensor/index.h
@@ -51,13 +51,15 @@ Expr take(Expr x, Expr indices, Optional<Integer> axis);
* \param strides Specifies the stride values, it can be negative in that case,
* the input tensor will be reversed in that particular axis.
* If it is `NullOpt`, it by default is an list of ones of the same length as
`axes`.
+ * \param assume_inbound Whether to assume the indices are in bound.
* \return The sliced result
*/
-Expr strided_slice(Expr x, //
- Array<Integer> axes, //
- Array<PrimExpr> begin, //
- Array<PrimExpr> end, //
- Optional<Array<PrimExpr>> strides);
+Expr strided_slice(Expr x, //
+ Array<Integer> axes, //
+ Array<PrimExpr> begin, //
+ Array<PrimExpr> end, //
+ Optional<Array<PrimExpr>> strides, //
+ bool assume_inbound = false);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_op_index.py
b/tests/python/relax/test_op_index.py
index 8b2f8c0b29..cc09b266f5 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -461,22 +461,22 @@ def test_strided_slice_infer_struct_info_shape_symbolic():
_check_inference(
bb,
relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]),
- relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n),
"float32"),
+ relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), "float32"),
)
_check_inference(
bb,
relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]),
- relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3,
n), "float32"),
+ relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n),
"float32"),
)
_check_inference(
bb,
relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]),
- relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m) + 1 - 1, n),
dtype=""),
+ relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), dtype=""),
)
_check_inference(
bb,
relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]),
- relax.TensorStructInfo(((tir.min(8, m) - tir.min(1, m) + 3 - 1) // 3,
n), dtype=""),
+ relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n),
dtype=""),
)
@@ -549,22 +549,84 @@ def
test_strided_slice_infer_struct_info_more_input_dtype():
def test_strided_slice_infer_struct_info_symbolic_begin_end_strides():
bb = relax.BlockBuilder()
- a = tir.Var("a", "int64")
+ var = tir.Var("var", "int64")
+ size_var = tir.SizeVar("size_var", "int64")
x = relax.Var("x", R.Tensor((8, 9), "float32"))
_check_inference(
bb,
- relax.op.strided_slice(x, axes=[0], begin=[a], end=[8]),
+ relax.op.strided_slice(x, axes=[0], begin=[var], end=[8]),
+ relax.TensorStructInfo(
+ (tir.max(8 - tir.max(tir.if_then_else(var < 0, var + 8, var), 0),
0), 9),
+ dtype="float32",
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8]),
+ relax.TensorStructInfo((tir.max(8 - size_var, 0), 9), dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[var]),
+ relax.TensorStructInfo(
+ (tir.min(tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 8),
9), dtype="float32"
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var]),
+ relax.TensorStructInfo((tir.min(size_var, 8), 9), dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var]),
+ relax.TensorStructInfo(dtype="float32", ndim=2),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[8],
strides=[size_var]),
relax.TensorStructInfo(dtype="float32", ndim=2),
)
+
+
+def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound():
+ bb = relax.BlockBuilder()
+ var = tir.Var("var", "int64")
+ size_var = tir.SizeVar("size_var", "int64")
+ x = relax.Var("x", R.Tensor((8, 9), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[var], end=[8],
assume_inbound=True),
+ relax.TensorStructInfo(
+ (8 - tir.if_then_else(var < 0, var + 8, var), 9),
+ dtype="float32",
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8],
assume_inbound=True),
+ relax.TensorStructInfo((8 - size_var, 9), dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[var],
assume_inbound=True),
+ relax.TensorStructInfo((tir.if_then_else(var < 0, var + 8, var), 9),
dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var],
assume_inbound=True),
+ relax.TensorStructInfo((size_var, 9), dtype="float32"),
+ )
_check_inference(
bb,
- relax.op.strided_slice(x, axes=[0], begin=[0], end=[a]),
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var],
assume_inbound=True),
relax.TensorStructInfo(dtype="float32", ndim=2),
)
_check_inference(
bb,
- relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[a]),
+ relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var],
assume_inbound=True),
relax.TensorStructInfo(dtype="float32", ndim=2),
)