This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7db0b984de [Unity][Op] Dynamic Strided Slice (#14548)
7db0b984de is described below
commit 7db0b984de6dd7d9ee7db79681476bc2bb7f26d5
Author: Sunghyun Park <[email protected]>
AuthorDate: Wed Apr 12 13:04:04 2023 -0700
[Unity][Op] Dynamic Strided Slice (#14548)
* feat: dyn_strided_slice op
* feat: shape computation
* feat: legalizer for dynamic strided slice
* remove whitespace
* reflect feedback
* fix
* fix
* remove whitespace
---
include/tvm/topi/transform.h | 37 +-
python/tvm/relax/op/index.py | 37 ++
python/tvm/relax/transform/legalize_ops/index.py | 68 ++-
python/tvm/script/ir_builder/relax/ir.py | 3 +-
python/tvm/topi/transform.py | 35 ++
src/relax/op/tensor/index.cc | 78 +++-
src/relax/transform/legalize_ops.cc | 10 +-
src/topi/transform.cc | 8 +
tests/python/relax/test_e2e_op_dynamic.py | 104 +++++
tests/python/relax/test_op_index.py | 197 +++++++++
..._transform_legalize_ops_index_linear_algebra.py | 489 +++++++++++++++++++++
tests/python/topi/python/test_topi_transform.py | 54 +++
12 files changed, 1111 insertions(+), 9 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 0490ae7f1e..579dbb5833 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -2021,7 +2021,6 @@ inline Tensor adv_index(const Tensor& data, const
Array<Tensor>& indices,
for (size_t i = 0; i < broadcast_shape.size(); ++i) {
tensor_indices.push_back(iter_var[i]);
}
-
Array<PrimExpr> real_indices;
for (size_t i = 0; i < bindices.size(); ++i) {
real_indices.push_back(bindices[i](tensor_indices));
@@ -2035,6 +2034,42 @@ inline Tensor adv_index(const Tensor& data, const
Array<Tensor>& indices,
name, tag);
}
+namespace relax {
+// relax dynamic slice
+inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor&
begin,
+ const te::Tensor& end, const
te::Tensor& strides,
+ Array<PrimExpr> output_shape,
+ std::string name =
"T_strided_slice_dynamic",
+ std::string tag = kInjective) {
+ const size_t num_dynamic_axes = x.ndim();
+ ICHECK_EQ(begin.ndim(), 1);
+ ICHECK_EQ(end.ndim(), 1);
+ ICHECK_EQ(strides.ndim(), 1);
+ const auto* len_begin = begin->shape[0].as<IntImmNode>();
+ const auto* len_end = end->shape[0].as<IntImmNode>();
+ const auto* len_strides = strides->shape[0].as<IntImmNode>();
+ ICHECK(len_begin);
+ ICHECK(len_end);
+ ICHECK(len_strides);
+ ICHECK_EQ(len_begin->value, num_dynamic_axes);
+ ICHECK_EQ(len_end->value, num_dynamic_axes);
+ ICHECK_EQ(len_strides->value, num_dynamic_axes);
+
+ return te::compute(
+ output_shape,
+ [&](const Array<tvm::tir::Var>& indices) {
+ Array<PrimExpr> real_indices;
+ for (size_t i = 0; i < num_dynamic_axes; ++i) {
+ auto ind = make_const(DataType::Int(64), i);
+ real_indices.push_back(indices[i] * strides(ind) +
tvm::min(begin(ind), x->shape[i] - 1));
+ }
+ return x(real_indices);
+ },
+ name, tag);
+}
+
+} // namespace relax
+
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_TRANSFORM_H_
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index b9acf11b9f..835c9350b0 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -91,3 +91,40 @@ def strided_slice(
same length as `axes`.
"""
return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore
+
+
+def dynamic_strided_slice(
+ x: Expr,
+ begin: Expr,
+ end: Expr,
+ strides: Expr,
+) -> Expr:
+    """Dynamic strided slice of a tensor. `begin`, `end`, `strides` can be
computed at runtime.
+
+ Parameters
+ ----------
+ x : Expr
+ The source tensor to be sliced.
+
+ begin : Expr
+ The indices to begin with in the slicing, inclusive.
+
+ end : Expr
+ The indices indicating end of the slice, exclusive.
+
+ strides : Expr
+ Specifies the stride values, it can be negative in that case,
+ the input tensor will be reversed in that particular axis.
+    If not specified, it by default is a list of ones of the same length
as the rank of `x`.
+
+ Returns
+ -------
+ ret : relax.Expr
+ The sliced result.
+
+ Note
+ ----
+    dynamic_strided_slice requires the input `begin`, `end` and `strides` to
have the
+    same length as the rank of the `data` tensor.
+ """
+ return _ffi_api.dynamic_strided_slice(x, begin, end, strides) # type:
ignore
diff --git a/python/tvm/relax/transform/legalize_ops/index.py
b/python/tvm/relax/transform/legalize_ops/index.py
index eccccc7c6d..8ee1bed9b9 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -18,9 +18,10 @@
"""Default legalization function for index operators."""
import logging
-from tvm import topi, tir
+from tvm import topi, tir, te
from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, ExternFunc
+from ...struct_info import ShapeStructInfo
from .common import register_legalize
@@ -59,3 +60,66 @@ def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
call.attrs.axes,
slice_mode="end",
)
+
+
+@register_legalize("relax.dynamic_strided_slice")
+def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr:
+ assert len(call.args) == 4
+ data, begin, end, strides = call.args
+
+ # 1. Insert shape function
+ def shape_func(data, begin, end, strides):
+ def _compute(i):
+ def canonicalize_index(index, extent, strides):
+ begin_range = tir.Select(strides < 0, tir.const(-1, "int64"),
tir.const(0, "int64"))
+ end_range = tir.Select(strides < 0, extent - 1, extent)
+ index = tir.Select(index < 0, index + extent, index)
+ return tir.Min(tir.Max(index, begin_range), end_range)
+
+ def get_length(begin, end, strides, length):
+ begin = canonicalize_index(begin, length, strides)
+ end = canonicalize_index(end, length, strides)
+ len1 = tir.ceildiv(begin - end, -strides)
+ len2 = tir.ceildiv(end - begin, strides)
+ return tir.Select(strides < 0, len1, len2)
+
+ length = tir.const(-1, "int64")
+ for idx in range(data.ndim):
+ length = tir.Select(i == tir.const(idx, "int64"),
data.shape[idx], length)
+
+ return get_length(begin[i], end[i], strides[i], length)
+
+ return te.compute((begin.shape[0],), _compute,
name="T_shape_func_strided_slice_dynamic")
+
+ output_shape = bb.normalize(
+ bb.call_te(
+ shape_func,
+ data,
+ begin,
+ end,
+ strides,
+ )
+ )
+
+ # 2. Convert tensor to shape and match cast with new symbolic vars
+ # Get shape length
+ ndim = int(output_shape.struct_info.shape[0])
+ output_shape = bb.emit(
+ Call(
+ ExternFunc("vm.builtin.tensor_to_shape"),
+ [output_shape],
+ sinfo_args=[ShapeStructInfo(ndim=ndim)],
+ )
+ )
+ output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)]
+ bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars))
+
+ # 3. Pass the output shape vars to TOPI
+ return bb.call_te(
+ topi.dynamic_strided_slice,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ call.args[3],
+ output_shape=output_shape_vars,
+ )
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index e390658c8f..39327c4b4a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -111,6 +111,7 @@ from tvm.relax.op import (
shape_of,
std,
strided_slice,
+ dynamic_strided_slice,
sum,
take,
variance,
@@ -639,7 +640,6 @@ __all__ = [
"ShapeExpr",
"std",
"str",
- "strided_slice",
"sum",
"sigmoid",
"sign",
@@ -652,6 +652,7 @@ __all__ = [
"stop_lift_params",
"str",
"strided_slice",
+ "dynamic_strided_slice",
"subtract",
"take",
"tan",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index e4fe3c5839..7807351e90 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -227,6 +227,41 @@ def strided_slice(a, begin, end, strides=None, axes=None,
slice_mode="end"):
return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)
+def dynamic_strided_slice(a, begin, end, strides, output_shape):
+ """Slice of an array.
+
+ Parameters
+ ----------
+ a : tvm.te.Tensor
+ The tensor to be sliced.
+
+ begin : tvm.te.Tensor
+ The indices to begin with in the slicing.
+
+ end : tvm.te.Tensor
+ Indices indicating end of the slice.
+
+ strides : tvm.te.Tensor
+ Specifies the stride values, it can be negative
+ in that case, the input tensor will be reversed
+ in that particular axis.
+
+ output_shape: list of PrimExpr
+ Specifies the output shape
+
+ Returns
+ -------
+ ret : tvm.te.Tensor
+ """
+ if not isinstance(begin, tvm.te.Tensor):
+ begin = const_vector(begin)
+ if not isinstance(end, tvm.te.Tensor):
+ end = const_vector(end)
+ if not isinstance(strides, tvm.te.Tensor):
+ strides = const_vector(strides)
+ return cpp.relax_dynamic_strided_slice(a, begin, end, strides,
output_shape)
+
+
@tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set")
def strided_set(a, v, begin, end, strides=None):
"""Set slice of an array.
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index b627e1425b..c3d38db4e1 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -190,7 +190,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
: Array<PrimExpr>(n_axis,
IntImm(DataType::Int(64), 1));
std::vector<int64_t> int_strides;
int_strides.reserve(n_axis);
- // Only do output shape inference when all the begin/end/stride values are
integers.
+ // Only do output shape inference when all the begin/end/strides values are
integers.
for (int i = 0; i < n_axis; ++i) {
const auto* int_begin = attrs->begin[i].as<IntImmNode>();
const auto* int_end = attrs->end[i].as<IntImmNode>();
@@ -204,7 +204,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
Array<PrimExpr> output_shape = data_shape->values;
for (int i = 0; i < n_axis; ++i) {
ICHECK_NE(int_strides[i], 0)
- << "Strided slice requires stride to be non-zero but got 0 for axis "
<< axes[i] << ".";
+ << "Strided slice requires strides to be non-zero but got 0 for axis "
<< axes[i] << ".";
output_shape.Set(axes[i], GetLength(attrs->begin[i], attrs->end[i],
int_strides[i],
data_shape->values[axes[i]]));
}
@@ -239,5 +239,79 @@ TVM_REGISTER_OP("relax.strided_slice")
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow);
+/* relax.dynamic_strided_slice */
+Expr dynamic_strided_slice(Expr x, //
+ Expr begin, //
+ Expr end, //
+ Expr strides) {
+ static const Op& op = Op::Get("relax.dynamic_strided_slice");
+ return Call(op, {std::move(x), std::move(begin), std::move(end),
std::move(strides)}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
+
+StructInfo InferStructInfoDynStridedSlice(const Call& call, const
BlockBuilder& ctx) {
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* begin_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+ const auto* end_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+ const auto* strides_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[3]);
+
+ ICHECK(data_sinfo);
+ if (data_sinfo->IsUnknownNdim()) {
+ LOG(WARNING) << "When data rank is unknown, dynamic strided slice assumes
begin/end/strides "
+ "tensors are well-formed. It could produce runtime error
when this assumption "
+ "turns out to be wrong.";
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+ }
+ if (data_sinfo->IsUnknownDtype()) {
+ LOG(WARNING) << "When data type is unknown, dynamic strided slice assumes
to have a valid "
+ "dtype. It could produce runtime error when this
assumption "
+ "turns out to be wrong.";
+ }
+
+ int n_axis = data_sinfo->ndim;
+ auto diag_def = [&](const TensorStructInfoNode* sinfo, String name) {
+ ICHECK(sinfo) << "Dynamic strided slice requires the input " << name
+ << " to be have the struct info. Please try normalizing the
inputs.";
+ CHECK_EQ(sinfo->ndim, 1) << "Dynamic strided slice requires " << name
+ << " to be 1d tensor (list of values).";
+ const auto* shape = sinfo->shape.as<ShapeExprNode>();
+ ICHECK(shape) << "Dynamic strided slice requires the input " << name
+ << " to have well-defined shape.";
+ // NOTE(tvm-team): This strong restriction seems necessary for now until
we have a generic
+ // solution in converting 1d Tensor with unknown num_elem to
Array<PrimExpr>.
+ const auto* num_elem = shape->values[0].as<IntImmNode>();
+ ICHECK(num_elem) << "Dynamic strided slice requires the input " << name
+ << " to have a known integer shape value.";
+ CHECK_EQ(num_elem->value, n_axis) << "Dynamic strided slice requires the
number of indices in "
+ << name << " to equal the number of
axes.";
+ if (sinfo->IsUnknownDtype()) {
+ LOG(WARNING) << "Dynamic strided slice assumes " << name
+ << " to be int64 when it is not specified.";
+ } else {
+ CHECK(sinfo->dtype == DataType::Int(64))
+ << "Dynamic strided_slice expects the input " << name
+ << "values to be all int64. However, " << name << " has dtype " <<
sinfo->dtype << ".";
+ }
+ };
+ diag_def(begin_sinfo, "begin");
+ diag_def(end_sinfo, "end");
+ diag_def(strides_sinfo, "strides");
+
+ // The output shape will depend on the runtime value in begin/end/strides
tensors.
+ // TODO(tvm-team): Currently, it is unable to express partially-static
shape. Revisit when
+ // PrimValue lands.
+ return TensorStructInfo(data_sinfo->dtype, n_axis);
+}
+
+// TODO(tvm-team): Register FRelaxInferLayout, TMixedPrecisionPolicy
+TVM_REGISTER_OP("relax.dynamic_strided_slice")
+ .set_num_inputs(4)
+ .add_argument("x", "Tensor", "The source tensor to be sliced.")
+ .add_argument("begin", "Tensor", "The indices to begin with in the
slicing.")
+ .add_argument("end", "Tensor", "Indices indicating end of the slice.")
+ .add_argument("strides", "Tensor", "The stride values.")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoDynStridedSlice);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/legalize_ops.cc
b/src/relax/transform/legalize_ops.cc
index 350a40c37b..7c5393c6ca 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -83,15 +83,19 @@ class LegalizeMutator : public ExprMutator {
return visited_call;
}
+ auto op = GetRef<Op>(op_node);
+ std::string op_name(op->name);
+ bool is_data_dependent_op = (op_name.find("dynamic") != std::string::npos);
// Not all shape values are known
+  // Data-dependent ops are an exception since their output shape will be
identified at runtime.
+ // Legalizer will insert their shape functions, which are manually
registered, and match cast
+ // to define symbolic output shape at compile time.
if (!std::all_of(visited_call->args.begin(), visited_call->args.end(),
[](Expr arg) { return
KnowAllShapeValues(GetStructInfo(arg)); }) ||
- !KnowAllShapeValues(GetStructInfo(visited_call))) {
+ (!is_data_dependent_op &&
!KnowAllShapeValues(GetStructInfo(visited_call)))) {
return visited_call;
}
- auto op = GetRef<Op>(op_node);
-
// Priority: customize > default.
// Check if it has customize legalization registered.
if (cmap_.defined() && cmap_.value().count(op->name)) {
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index bbefa19c20..906f2b9115 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -201,6 +201,14 @@
TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMR
*rv = dynamic_strided_slice(args[0], begin, end, strides);
});
+TVM_REGISTER_GLOBAL("topi.relax_dynamic_strided_slice").set_body([](TVMArgs
args, TVMRetValue* rv) {
+ te::Tensor begin = args[1];
+ te::Tensor end = args[2];
+ te::Tensor strides = args[3];
+ Array<PrimExpr> output_shape = args[4];
+ *rv = relax::dynamic_strided_slice(args[0], begin, end, strides,
output_shape);
+});
+
TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv)
{
int depth = args[3];
int axis = args[4];
diff --git a/tests/python/relax/test_e2e_op_dynamic.py
b/tests/python/relax/test_e2e_op_dynamic.py
new file mode 100644
index 0000000000..1e9414c15d
--- /dev/null
+++ b/tests/python/relax/test_e2e_op_dynamic.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+import tvm
+from tvm import relax
+import tvm.topi.testing
+from tvm.relax.transform import LegalizeOps
+from tvm.script import relax as R, tir as T
+import tvm.testing
+
+# TODO(tvm-team): `tir.transform.DefaultGPUSchedule` does not work.
+target, dev = "llvm", tvm.cpu()
+
+
+def build(mod):
+ exe = relax.build(mod, target=target)
+ return relax.VirtualMachine(exe, dev)
+
+
[email protected](
+ "begin, end, strides",
+ [
+ ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]),
+ ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]),
+ ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]),
+ ],
+)
+def test_dynamic_strided_slice(begin, end, strides):
+ # fmt: off
+ @tvm.script.ir_module
+ class DynamicStridedSlice:
+ @R.function
+ def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin:
R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides:
R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
+ gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x,
begin, end, strides)
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(DynamicStridedSlice)
+ with tvm.target.Target(target):
+ mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+ vm = build(mod)
+
+ x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
+ data_nd = tvm.nd.array(x_np, dev)
+ begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
+ end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
+ strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
+
+ # Reference implementation
+ out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides)
+ out_nd = vm["main"](data_nd, begin_nd, end_nd, strides_nd)
+ tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+
+
[email protected](
+ "begin, end, strides",
+ [
+ ([0, 2, 4, 4], [5, 5, 7, 8], [1, 1, 2, 3]),
+ ([0, 2, 4, 4], [5, 5, 11, 10], [1, 1, 1, 1]),
+ ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]),
+ ],
+)
+def test_dynamic_strided_slice_symbolic(begin, end, strides):
+ # fmt: off
+ @tvm.script.ir_module
+ class DynamicStridedSlice:
+ @R.function
+ def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin:
R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides:
R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
+ m = T.int64()
+ n = T.int64()
+ gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x,
begin, end, strides)
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(DynamicStridedSlice)
+ vm = build(mod)
+
+ x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
+ data_nd = tvm.nd.array(x_np, dev)
+ begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
+ end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
+ strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
+
+ # Reference implementation
+ out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides)
+ out_nd = vm["main"](data_nd, begin_nd, end_nd, strides_nd)
+ tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_op_index.py
b/tests/python/relax/test_op_index.py
index 9390a2c9b0..8b2f8c0b29 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -30,6 +30,7 @@ def test_op_correctness():
assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op ==
Op.get(
"relax.strided_slice"
)
+ assert relax.op.dynamic_strided_slice(x, x, x, x).op ==
Op.get("relax.dynamic_strided_slice")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -673,5 +674,201 @@ def
test_strided_slice_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]))
+def test_dynamic_strided_slice_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((8, 9, 10, 10)))
+ x4 = relax.Var("x", R.Tensor(ndim=4))
+ x5 = relax.Var("x", R.Tensor())
+
+ b0 = relax.Var("begin", R.Tensor((4,), "int64"))
+ e0 = relax.Var("end", R.Tensor((4,), "int64"))
+ s0 = relax.Var("strides", R.Tensor((4,), "int64"))
+ b1 = relax.Var("begin", R.Tensor((4,)))
+ e1 = relax.Var("end", R.Tensor((4,)))
+ s1 = relax.Var("stride", R.Tensor((4,)))
+
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x0, b0, e0, s0),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x1, b0, e0, s0),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x2, b0, e0, s0),
+ R.Tensor("float32", ndim=-1),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x3, b0, e0, s0),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x4, b0, e0, s0),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x5, b0, e0, s0),
+ R.Tensor(ndim=-1),
+ )
+
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x0, b1, e1, s1),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x1, b1, e1, s1),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x2, b1, e1, s1),
+ R.Tensor("float32", ndim=-1),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x3, b1, e1, s1),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x4, b1, e1, s1),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x5, b1, e1, s1),
+ R.Tensor(ndim=-1),
+ )
+
+
+def test_dynamic_strided_slice_infer_struct_info_symbolic():
+ bb = relax.BlockBuilder()
+ i = tir.Var("i", "int64")
+ j = tir.Var("j", "int64")
+ k = tir.Var("k", "int64")
+ l = tir.Var("l", "int64")
+ x0 = relax.Var("x", R.Tensor((i, j, k, l), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((i, j, k, l)))
+ x4 = relax.Var("x", R.Tensor(ndim=4))
+ x5 = relax.Var("x", R.Tensor())
+
+ b0 = relax.Var("begin", R.Tensor((4,), "int64"))
+ e0 = relax.Var("end", R.Tensor((4,), "int64"))
+ s0 = relax.Var("stride", R.Tensor((4,), "int64"))
+ b1 = relax.Var("begin", R.Tensor((4,)))
+ e1 = relax.Var("end", R.Tensor((4,)))
+ s1 = relax.Var("stride", R.Tensor((4,)))
+
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x0, b0, e0, s0),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x1, b0, e0, s0),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x2, b0, e0, s0),
+ R.Tensor("float32", ndim=-1),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x3, b0, e0, s0),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x4, b0, e0, s0),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x5, b0, e0, s0),
+ R.Tensor(ndim=-1),
+ )
+
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x0, b1, e1, s1),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x1, b1, e1, s1),
+ R.Tensor("float32", ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x2, b1, e1, s1),
+ R.Tensor("float32", ndim=-1),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x3, b1, e1, s1),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x4, b1, e1, s1),
+ R.Tensor(ndim=4),
+ )
+ _check_inference(
+ bb,
+ relax.op.dynamic_strided_slice(x5, b1, e1, s1),
+ R.Tensor(ndim=-1),
+ )
+
+
+def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+ b0 = relax.Var("begin", R.Tensor((4,), "float32"))
+ e0 = relax.Var("end", R.Tensor((4,), "float32"))
+ s0 = relax.Var("stride", R.Tensor((4,), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.strided_slice(x0, b0, e0, s0))
+
+
+def test_dynamic_strided_slice_infer_struct_info_arg_wrong_shape_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32"))
+ m = tir.Var("m", "int64")
+ # invalid arg
+ b0 = relax.Var("begin", R.Tensor("int64", ndim=2))
+ b1 = relax.Var("begin", R.Tensor((1,), "int64"))
+ b2 = relax.Var("begin", R.Tensor((2, 2), "int64"))
+ b3 = relax.Var("begin", R.Tensor((m,), "int64"))
+ # valid args
+ e0 = relax.Var("end", R.Tensor((4,), "int64"))
+ s0 = relax.Var("stride", R.Tensor((4,), "int64"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.strided_slice(x0, b0, e0, s0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.strided_slice(x0, b1, e0, s0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.strided_slice(x0, b2, e0, s0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.strided_slice(x0, b3, e0, s0))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 39224240ef..d7c0b54af2 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -210,6 +210,495 @@ def test_strided_slice_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dynamic_strided_slice():
+ # fmt: off
+ @tvm.script.ir_module
+ class DynamicStridedSlice:
+ @R.function
+ def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin:
R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides:
R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
+ gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x,
begin, end, strides)
+ return gv
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def dynamic_strided_slice(
+ rxplaceholder: T.Buffer(
+ (T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"
+ ),
+ rxplaceholder_1: T.Buffer((T.int64(4),), "int64"),
+ rxplaceholder_2: T.Buffer((T.int64(4),), "int64"),
+ rxplaceholder_3: T.Buffer((T.int64(4),), "int64"),
+ var_T_strided_slice_dynamic: T.handle,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ s, s_1, s_2, s_3 = T.int64(), T.int64(), T.int64(), T.int64()
+ T_strided_slice_dynamic = T.match_buffer(
+ var_T_strided_slice_dynamic, (s, s_1, s_2, s_3)
+ )
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(s, s_1, s_2, s_3):
+ with T.block("T_strided_slice_dynamic"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(
+ rxplaceholder[
+ T.min(rxplaceholder_1[T.int64(0)], T.int64(7))
+ + v_ax0 * rxplaceholder_3[T.int64(0)],
+ T.min(rxplaceholder_1[T.int64(1)], T.int64(8))
+ + v_ax1 * rxplaceholder_3[T.int64(1)],
+ T.min(rxplaceholder_1[T.int64(2)], T.int64(9))
+ + v_ax2 * rxplaceholder_3[T.int64(2)],
+ T.min(rxplaceholder_1[T.int64(3)], T.int64(9))
+ + v_ax3 * rxplaceholder_3[T.int64(3)],
+ ],
+ rxplaceholder_1[T.int64(0) : T.int64(4)],
+ rxplaceholder_3[T.int64(0) : T.int64(4)],
+ )
+ T.writes(T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[
+ T.min(rxplaceholder_1[T.int64(0)], T.int64(7))
+ + v_ax0 * rxplaceholder_3[T.int64(0)],
+ T.min(rxplaceholder_1[T.int64(1)], T.int64(8))
+ + v_ax1 * rxplaceholder_3[T.int64(1)],
+ T.min(rxplaceholder_1[T.int64(2)], T.int64(9))
+ + v_ax2 * rxplaceholder_3[T.int64(2)],
+ T.min(rxplaceholder_1[T.int64(3)], T.int64(9))
+ + v_ax3 * rxplaceholder_3[T.int64(3)],
+ ]
+
+ @T.prim_func
+ def shape_func(
+ rxplaceholder: T.Buffer(
+ (T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"
+ ),
+ rxplaceholder_1: T.Buffer((T.int64(4),), "int64"),
+ rxplaceholder_2: T.Buffer((T.int64(4),), "int64"),
+ rxplaceholder_3: T.Buffer((T.int64(4),), "int64"),
+ T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),),
"int64"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i in range(T.int64(4)):
+ with T.block("T_shape_func_strided_slice_dynamic"):
+ v_i = T.axis.spatial(T.int64(4), i)
+ T.reads(
+ rxplaceholder_3[v_i], rxplaceholder_1[v_i],
rxplaceholder_2[v_i]
+ )
+ T.writes(T_shape_func_strided_slice_dynamic[v_i])
+ T_shape_func_strided_slice_dynamic[v_i] = T.Select(
+ rxplaceholder_3[v_i] < T.int64(0),
+ (
+ T.min(
+ T.max(
+ T.Select(
+ rxplaceholder_1[v_i] < T.int64(0),
+ rxplaceholder_1[v_i]
+ + T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0),
+ T.int64(8),
+ T.int64(-1),
+ ),
+ ),
+ ),
+ ),
+ rxplaceholder_1[v_i],
+ ),
+ T.int64(-1),
+ ),
+ T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0), T.int64(8),
T.int64(-1)
+ ),
+ ),
+ ),
+ )
+ - T.int64(1),
+ )
+ - T.min(
+ T.max(
+ T.Select(
+ rxplaceholder_2[v_i] < T.int64(0),
+ rxplaceholder_2[v_i]
+ + T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0),
+ T.int64(8),
+ T.int64(-1),
+ ),
+ ),
+ ),
+ ),
+ rxplaceholder_2[v_i],
+ ),
+ T.int64(-1),
+ ),
+ T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0), T.int64(8),
T.int64(-1)
+ ),
+ ),
+ ),
+ )
+ - T.int64(1),
+ )
+ - rxplaceholder_3[v_i]
+ - T.int64(1)
+ )
+ // (rxplaceholder_3[v_i] * T.int64(-1)),
+ (
+ T.min(
+ T.max(
+ T.Select(
+ rxplaceholder_2[v_i] < T.int64(0),
+ rxplaceholder_2[v_i]
+ + T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0),
+ T.int64(8),
+ T.int64(-1),
+ ),
+ ),
+ ),
+ ),
+ rxplaceholder_2[v_i],
+ ),
+ T.int64(0),
+ ),
+ T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0), T.int64(8),
T.int64(-1)
+ ),
+ ),
+ ),
+ ),
+ )
+ + rxplaceholder_3[v_i]
+ - T.min(
+ T.max(
+ T.Select(
+ rxplaceholder_1[v_i] < T.int64(0),
+ rxplaceholder_1[v_i]
+ + T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0),
+ T.int64(8),
+ T.int64(-1),
+ ),
+ ),
+ ),
+ ),
+ rxplaceholder_1[v_i],
+ ),
+ T.int64(0),
+ ),
+ T.Select(
+ v_i == T.int64(3),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(2),
+ T.int64(10),
+ T.Select(
+ v_i == T.int64(1),
+ T.int64(9),
+ T.Select(
+ v_i == T.int64(0), T.int64(8),
T.int64(-1)
+ ),
+ ),
+ ),
+ ),
+ )
+ - T.int64(1)
+ )
+ // rxplaceholder_3[v_i],
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((8, 9, 10, 10), dtype="float32"),
+ begin: R.Tensor((4,), dtype="int64"),
+ end: R.Tensor((4,), dtype="int64"),
+ strides: R.Tensor((4,), dtype="int64"),
+ ) -> R.Tensor(dtype="float32", ndim=4):
+ s = T.int64()
+ s_1 = T.int64()
+ s_2 = T.int64()
+ s_3 = T.int64()
+ gv = R.call_tir(
+ Expected.shape_func,
+ (x, begin, end, strides),
+ out_sinfo=R.Tensor((4,), dtype="int64"),
+ )
+ gv1: R.Shape(ndim=4) = R.call_packed(
+ "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),)
+ )
+ gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast(
+ gv1, R.Shape([s, s_1, s_2, s_3])
+ )
+ gv_1 = R.call_tir(
+ Expected.dynamic_strided_slice,
+ (x, begin, end, strides),
+ out_sinfo=R.Tensor((s, s_1, s_2, s_3), dtype="float32"),
+ )
+ return gv_1
+ # fmt: on
+ mod = LegalizeOps()(DynamicStridedSlice)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_dynamic_strided_slice_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class DynamicStridedSlice:
+ @R.function
+ def main(x: R.Tensor((10, "n"), "float32"), begin:R.Tensor((2,),
"int64"), end:R.Tensor((2,), "int64"), strides:R.Tensor((2,), "int64")) ->
R.Tensor("float32", ndim=2):
+ n = T.int64()
+ gv: R.Tensor("float32", ndim=2) = R.dynamic_strided_slice(x,
begin, end, strides)
+ return gv
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def dynamic_strided_slice(
+ var_rxplaceholder: T.handle,
+ rxplaceholder: T.Buffer((T.int64(2),), "int64"),
+ rxplaceholder_1: T.Buffer((T.int64(2),), "int64"),
+ rxplaceholder_2: T.Buffer((T.int64(2),), "int64"),
+ var_T_strided_slice_dynamic: T.handle,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int64()
+ rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10),
n))
+ s, s_1 = T.int64(), T.int64()
+ T_strided_slice_dynamic =
T.match_buffer(var_T_strided_slice_dynamic, (s, s_1))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(s, s_1):
+ with T.block("T_strided_slice_dynamic"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(
+ rxplaceholder_3[
+ T.min(rxplaceholder[T.int64(0)], T.int64(9))
+ + v_ax0 * rxplaceholder_2[T.int64(0)],
+ T.min(rxplaceholder[T.int64(1)], n - T.int64(1))
+ + v_ax1 * rxplaceholder_2[T.int64(1)],
+ ],
+ rxplaceholder[T.int64(0) : T.int64(2)],
+ rxplaceholder_2[T.int64(0) : T.int64(2)],
+ )
+ T.writes(T_strided_slice_dynamic[v_ax0, v_ax1])
+ T_strided_slice_dynamic[v_ax0, v_ax1] = rxplaceholder_3[
+ T.min(rxplaceholder[T.int64(0)], T.int64(9))
+ + v_ax0 * rxplaceholder_2[T.int64(0)],
+ T.min(rxplaceholder[T.int64(1)], n - T.int64(1))
+ + v_ax1 * rxplaceholder_2[T.int64(1)],
+ ]
+
+ @T.prim_func
+ def shape_func(
+ var_rxplaceholder: T.handle,
+ rxplaceholder: T.Buffer((T.int64(2),), "int64"),
+ rxplaceholder_1: T.Buffer((T.int64(2),), "int64"),
+ rxplaceholder_2: T.Buffer((T.int64(2),), "int64"),
+ T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(2),),
"int64"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int64()
+ rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10),
n))
+ # with T.block("root"):
+ for i in range(T.int64(2)):
+ with T.block("T_shape_func_strided_slice_dynamic"):
+ v_i = T.axis.spatial(T.int64(2), i)
+ T.reads(rxplaceholder_2[v_i], rxplaceholder[v_i],
rxplaceholder_1[v_i])
+ T.writes(T_shape_func_strided_slice_dynamic[v_i])
+ T_shape_func_strided_slice_dynamic[v_i] = T.Select(
+ rxplaceholder_2[v_i] < T.int64(0),
+ (
+ T.min(
+ T.max(
+ T.Select(
+ rxplaceholder[v_i] < T.int64(0),
+ rxplaceholder[v_i]
+ + T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(
+ v_i == T.int64(0),
T.int64(10), T.int64(-1)
+ ),
+ ),
+ rxplaceholder[v_i],
+ ),
+ T.int64(-1),
+ ),
+ T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(v_i == T.int64(0), T.int64(10),
T.int64(-1)),
+ )
+ - T.int64(1),
+ )
+ - T.min(
+ T.max(
+ T.Select(
+ rxplaceholder_1[v_i] < T.int64(0),
+ rxplaceholder_1[v_i]
+ + T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(
+ v_i == T.int64(0),
T.int64(10), T.int64(-1)
+ ),
+ ),
+ rxplaceholder_1[v_i],
+ ),
+ T.int64(-1),
+ ),
+ T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(v_i == T.int64(0), T.int64(10),
T.int64(-1)),
+ )
+ - T.int64(1),
+ )
+ - rxplaceholder_2[v_i]
+ - T.int64(1)
+ )
+ // (rxplaceholder_2[v_i] * T.int64(-1)),
+ (
+ T.min(
+ T.max(
+ T.Select(
+ rxplaceholder_1[v_i] < T.int64(0),
+ rxplaceholder_1[v_i]
+ + T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(
+ v_i == T.int64(0),
T.int64(10), T.int64(-1)
+ ),
+ ),
+ rxplaceholder_1[v_i],
+ ),
+ T.int64(0),
+ ),
+ T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(v_i == T.int64(0), T.int64(10),
T.int64(-1)),
+ ),
+ )
+ + rxplaceholder_2[v_i]
+ - T.min(
+ T.max(
+ T.Select(
+ rxplaceholder[v_i] < T.int64(0),
+ rxplaceholder[v_i]
+ + T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(
+ v_i == T.int64(0),
T.int64(10), T.int64(-1)
+ ),
+ ),
+ rxplaceholder[v_i],
+ ),
+ T.int64(0),
+ ),
+ T.Select(
+ v_i == T.int64(1),
+ n,
+ T.Select(v_i == T.int64(0), T.int64(10),
T.int64(-1)),
+ ),
+ )
+ - T.int64(1)
+ )
+ // rxplaceholder_2[v_i],
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((10, "n"), dtype="float32"),
+ begin: R.Tensor((2,), dtype="int64"),
+ end: R.Tensor((2,), dtype="int64"),
+ strides: R.Tensor((2,), dtype="int64"),
+ ) -> R.Tensor(dtype="float32", ndim=2):
+ n = T.int64()
+ s = T.int64()
+ s_1 = T.int64()
+ gv = R.call_tir(
+ Expected.shape_func,
+ (x, begin, end, strides),
+ out_sinfo=R.Tensor((2,), dtype="int64"),
+ )
+ gv1: R.Shape(ndim=2) = R.call_packed(
+ "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),)
+ )
+ gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1]))
+ gv_1 = R.call_tir(
+ Expected.dynamic_strided_slice,
+ (x, begin, end, strides),
+ out_sinfo=R.Tensor((s, s_1), dtype="float32"),
+ )
+ return gv_1
+ # fmt: on
+
+ mod = LegalizeOps()(DynamicStridedSlice)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
##################### Linear algebra #####################
diff --git a/tests/python/topi/python/test_topi_transform.py
b/tests/python/topi/python/test_topi_transform.py
index 5866ffd5f7..ac69a2c85e 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -457,6 +457,51 @@ def verify_dynamic_strided_slice(in_shape, begin, end,
strides=None):
check_device(target)
+def verify_relax_dynamic_strided_slice(in_shape, begin, end, strides,
output_shape):
+ A = te.placeholder(shape=in_shape, name="A")
+ Begin = te.placeholder(shape=[len(in_shape)], name="begin", dtype="int64")
+ End = te.placeholder(shape=[len(in_shape)], name="end", dtype="int64")
+ Strides = te.placeholder(shape=[len(in_shape)], name="strides",
dtype="int64")
+
+ B = topi.dynamic_strided_slice(A, Begin, End, Strides, output_shape) + 1
+
+ OutShape = topi.shape_func_dynamic_strided_slice(A, Begin, End, Strides)
+
+ def check_device(target):
+ dev = tvm.device(target, 0)
+ if not tvm.testing.device_enabled(target):
+ print("Skip because %s is not enabled" % target)
+ return
+ print("Running on target: %s" % target)
+ x_np = np.random.uniform(size=in_shape).astype(A.dtype)
+ out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end,
strides) + 1
+ data_nd = tvm.nd.array(x_np, dev)
+ out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
+ begin_nd = tvm.nd.array(np.array(begin).astype("int64"), dev)
+ end_nd = tvm.nd.array(np.array(end).astype("int64"), dev)
+ strides_nd = tvm.nd.array(np.array(strides).astype("int64"), dev)
+
+ if target == "llvm":
+ # Check shape func
+ s = tvm.te.create_schedule(OutShape.op)
+ bar = tvm.build(
+ s, [A, Begin, End, Strides, OutShape], target,
name="shape_func_stride_slice"
+ )
+ out_shape_nd = tvm.nd.empty((len(out_npy.shape),), device=dev,
dtype="int64")
+ bar(data_nd, begin_nd, end_nd, strides_nd, out_shape_nd)
+
+ tvm.testing.assert_allclose(out_shape_nd.numpy(), output_shape)
+
+ with tvm.target.Target(target):
+ s = tvm.topi.testing.get_injective_schedule(target)(B)
+ foo = tvm.build(s, [A, Begin, End, Strides, B], target,
name="stride_slice")
+ foo(data_nd, begin_nd, end_nd, strides_nd, out_nd)
+ tvm.testing.assert_allclose(out_nd.numpy(), out_npy)
+
+ for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
+ check_device(target)
+
+
def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
A = te.placeholder(shape=in_shape, name="A")
V = te.placeholder(shape=v_shape, name="V")
@@ -859,6 +904,15 @@ def test_dynamic_strided_slice():
verify_dynamic_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
[email protected]_gpu
+def test_relax_dynamic_strided_slice():
+ verify_relax_dynamic_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1,
-1, 2], [3, 1, 2])
+ verify_relax_dynamic_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1,
1], [1, 3, 3])
+ verify_relax_dynamic_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1,
2], [1, 2, 2])
+ verify_relax_dynamic_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [1, 1,
1], [2, 3, 3])
+ verify_relax_dynamic_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3], [1, 1,
1], [1, 0, 3])
+
+
@tvm.testing.uses_gpu
def test_strided_set():
verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])