This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 684218f72d [Unity][Topi] Handle variable begin/end axes in
topi.strided_slice (#15520)
684218f72d is described below
commit 684218f72d5d4856188e99f8073855ed9a722d55
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Aug 17 09:19:10 2023 -0500
[Unity][Topi] Handle variable begin/end axes in topi.strided_slice (#15520)
* [Unity][Topi] Handle variable begin/end axes in topi.strided_slice
Prior to this commit, the checks for a dynamic strided slice were
implemented differently on the Python and C++ sides. In Python,
strided slice delegated to dynamic strided slice whenever the
begin/end was a `te.Tensor`. In C++, strided slice delegated to
dynamic strided slice whenever the begin/end was not a list of static
integers. As a result, a list of TIR expressions would pass the
Python checks for strided slice, but be delegated to dynamic strided
slice within the C++ implementation. Because dynamic strided slice
did not support the `axes` argument, any `axes` specified would be
silently ignored.
This commit implements `dynamic_strided_slice_with_axes`, and
delegates to it or the previous `dynamic_strided_slice` depending on
whether the `axes` are present.
This PR is a follow-up to #15450, which allowed dynamic begin/end
expressions in `relax.op.strided_slice`. This issue was first noticed
in a Relax graph whose begin/end bounds were symbolic expressions
(known at compile time, but not constant integers).
* Updated impacted unit test
---
include/tvm/topi/transform.h | 54 +++++++++++++++++++++++++
src/topi/transform.cc | 10 +++--
tests/python/relax/test_op_index.py | 2 +-
tests/python/topi/python/test_topi_transform.py | 46 ++++++++++++++++++++-
4 files changed, 107 insertions(+), 5 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index e2b19dd533..15f755df59 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -633,6 +633,60 @@ inline Array<Tensor> split(const Tensor& x,
Array<PrimExpr> split_indices, int a
return result;
}
+/*!
+ * \brief strided_slice of a tensor where begin/end/stride can be mixed static
+ * and dynamic, applied only to the listed axes
+ *
+ * begin/end are used exactly as given: they are not canonicalized (negative
+ * values are not wrapped around) and not clamped to the input extent.
+ * Axes that do not appear in `axes` are copied through unchanged.
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing, one per entry of `axes`
+ * \param end Indices indicating the (exclusive) end of the slice, one per entry of `axes`
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Specifies which axes will be updated. Each axis must lie in
+ * [0, x->shape.size()).
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the dynamic_strided_slice_with_axes operation
+ */
+inline Tensor dynamic_strided_slice_with_axes(
+    const Tensor& x, const Array<PrimExpr>& begin, const Array<PrimExpr>& end,
+    const Array<PrimExpr>& strides, const Array<Integer>& axes,
+    std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = kInjective) {
+  const size_t src_tensor_dim = x->shape.size();
+  ICHECK_EQ(begin.size(), end.size());
+  ICHECK_EQ(begin.size(), strides.size());
+  ICHECK_EQ(begin.size(), axes.size());
+  ICHECK_LE(begin.size(), src_tensor_dim);
+
+  for (const auto& axis_imm : axes) {
+    const int64_t axis = axis_imm->value;
+    // Compare in signed arithmetic so that a negative axis is reported as
+    // such, rather than wrapping around to a huge unsigned value.
+    ICHECK(axis >= 0 && axis < static_cast<int64_t>(src_tensor_dim))
+        << "dynamic_strided_slice_with_axes: axis " << axis
+        << " is out of range for a tensor of rank " << src_tensor_dim;
+  }
+
+  arith::Analyzer analyzer;
+
+  Array<PrimExpr> out_shape = x->shape;
+  for (size_t i = 0; i < begin.size(); i++) {
+    const int64_t axis = axes[i]->value;
+    PrimExpr extent;
+    if (tvm::tir::is_one(strides[i])) {
+      // Unit stride: the extent is simply the length of the range.
+      extent = end[i] - begin[i];
+    } else {
+      // The number of selected elements is ceil((end - begin) / stride).
+      // Plain floor division undercounts whenever the stride does not
+      // divide the range (e.g. [0, 5) with stride 2 selects 0, 2, 4).
+      // -floordiv(begin - end, stride) is ceil division for either sign
+      // of the stride.
+      extent = -indexdiv(begin[i] - end[i], strides[i]);
+    }
+    out_shape.Set(axis, analyzer.Simplify(extent));
+  }
+
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices =
+            indices.Map([](const auto& var) -> PrimExpr { return var; });
+
+        // Map each sliced output index back to the input: begin + index * stride.
+        for (size_t i = 0; i < begin.size(); i++) {
+          const int64_t axis = axes[i]->value;
+          PrimExpr new_index = indices[axis] * strides[i] + begin[i];
+          real_indices.Set(axis, new_index);
+        }
+
+        return x(real_indices);
+      },
+      name, tag);
+}
+
/*!
* \brief strided_slice of a tensor where begin/end/stride can be mixed static
and dynamic
*
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 906f2b9115..fc91f4bbf1 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -178,19 +178,23 @@
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
Array<PrimExpr> begin = args[1];
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
+ Array<Integer> axes = args[4];
if (IsConstIntArray(begin) && IsConstIntArray(end) &&
IsConstIntArray(strides)) {
Array<Integer> begin_static = args[1];
Array<Integer> end_static = args[2];
Array<Integer> strides_static = args[3];
- Array<Integer> axes = args[4];
std::string slice_mode = args[5];
- if (axes.size() > 0) {
+ if (axes.size()) {
*rv = strided_slice_with_axes(x, begin_static, end_static,
strides_static, axes, slice_mode);
} else {
*rv = strided_slice(x, begin_static, end_static, strides_static,
slice_mode);
}
} else {
- *rv = dynamic_strided_slice(x, begin, end, strides);
+ if (axes.size()) {
+ *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes);
+ } else {
+ *rv = dynamic_strided_slice(x, begin, end, strides);
+ }
}
});
diff --git a/tests/python/relax/test_op_index.py
b/tests/python/relax/test_op_index.py
index e31c383ce7..fffd31f6f7 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -964,7 +964,7 @@ def test_legalize_dynamic_begin_end():
for iters in T.grid(*B.shape):
with T.block("T_dynamic_strided_slice"):
i, j = T.axis.remap("SS", iters)
- B[i, j] = A[T.min(index, T.int64(15)) + i, j]
+ B[i, j] = A[i + index, j]
after = tvm.relax.transform.LegalizeOps()(before)
tvm.ir.assert_structural_equal(expected, after)
diff --git a/tests/python/topi/python/test_topi_transform.py
b/tests/python/topi/python/test_topi_transform.py
index e182c8793d..862f4a66ed 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -19,10 +19,12 @@ import numpy as np
import pytest
import tvm
from tvm import te
+from tvm import tir
from tvm import topi
from tvm import relay
import tvm.topi.testing
from tvm.contrib.nvcc import have_fp16
+from tvm.script import tir as T
import tvm.testing
@@ -892,7 +894,49 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None])
- verify_strided_slice((3, 4, 3), [0], [2], None, axes=[1])
+
+
+def test_strided_slice_with_dynamic_bounds():
+ """The begin/end of strided_slice can be a PrimExpr
+
+ Where topi.dynamic_strided_slice uses begin/end values provided at
+ runtime, strided_slice takes begin/end values at compile-time.
+ However, these begin/end values may depend on dynamic variables.
+ Previously, these resulted in dispatch to
+ `tvm::topi::dynamic_strided_slice`, ignoring the `axes` argument.
+
+ The axes are given out of order ([2, 1]) so that the test fails if
+ `axes` is ignored or applied positionally.
+ """
+ A = te.placeholder(shape=[16, 32, 64], name="A")
+ begins = [tir.Var("begin1", "int32"), tir.Var("begin2", "int32")]
+ ends = [tir.Var("end1", "int32"), tir.Var("end2", "int32")]
+ strides = [1, 1]
+ # begins[0]/ends[0] slice axis 2; begins[1]/ends[1] slice axis 1.
+ axes = [2, 1]
+
+ # Dummy tensor to provide begin/end variables in PrimFunc scope.
+ # Outside of a test case, these would typically be provided
+ # through another means, or bound to a static value at a later
+ # point.
+ Dummy = te.placeholder(shape=[*begins, *ends], name="Dummy")
+
+ B = topi.strided_slice(A, begins, ends, strides, axes)
+
+ func = te.create_prim_func([A, Dummy, B]).without_attr("global_symbol")
+
+ @T.prim_func(private=True)
+ def expected(
+ A: T.Buffer((16, 32, 64), "float32"),
+ var_Dummy: T.handle,
+ B_handle: T.handle,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ begin1, begin2, end1, end2 = T.int32(), T.int32(), T.int32(), T.int32()
+ Dummy = T.match_buffer(var_Dummy, (begin1, begin2, end1, end2))
+ # Axis 0 is untouched; axis 1 is sliced by begin2/end2 and axis 2 by
+ # begin1/end1, matching axes=[2, 1] above (unit strides).
+ B = T.match_buffer(B_handle, (16, end2 - begin2, end1 - begin1))
+ for iters in T.grid(*B.shape):
+ with T.block("T_dynamic_strided_slice_with_axes"):
+ i, j, k = T.axis.remap("SSS", iters)
+ B[i, j, k] = A[i, j + begin2, k + begin1]
+
+ tvm.ir.assert_structural_equal(expected, func)
@tvm.testing.uses_gpu