This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 684218f72d [Unity][Topi] Handle variable begin/end axes in topi.strided_slice (#15520)
684218f72d is described below

commit 684218f72d5d4856188e99f8073855ed9a722d55
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Aug 17 09:19:10 2023 -0500

    [Unity][Topi] Handle variable begin/end axes in topi.strided_slice (#15520)
    
    * [Unity][Topi] Handle variable begin/end axes in topi.strided_slice
    
    Prior to this commit, the checks for a dynamic strided slice were
    implemented differently on the python and C++ sides.  In Python,
    strided slice delegated to dynamic strided slice whenever the
    begin/end was a `te.Tensor`.  In C++, strided slice delegated to
    dynamic strided slice whenever the begin/end was not a list of static
    integers.  As a result, a list of TIR expressions would pass the
    Python checks for strided slice, but be delegated to dynamic strided
    slice within the C++ implementation.  Because dynamic strided slice
    did not support the `axes` argument, any `axes` specified would be
    silently ignored.
    
    This commit implements `dynamic_strided_slice_with_axes`, and
    delegates to it or the previous `dynamic_strided_slice` depending on
    whether the `axes` are present.
    
    This PR is a follow-up to #15450, which allowed dynamic begin/end
    expressions in `relax.op.strided_slice`.  This issue was first noticed
    in a relax graph with variable (but compile-time) bounds on begin/end.
    
    * Updated impacted unit test
---
 include/tvm/topi/transform.h                    | 54 +++++++++++++++++++++++++
 src/topi/transform.cc                           | 10 +++--
 tests/python/relax/test_op_index.py             |  2 +-
 tests/python/topi/python/test_topi_transform.py | 46 ++++++++++++++++++++-
 4 files changed, 107 insertions(+), 5 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index e2b19dd533..15f755df59 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -633,6 +633,60 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
   return result;
 }
 
+/*!
+ * \brief strided_slice of a tensor along selected axes, where begin/end/stride
+ * can be mixed static and dynamic
+ *
+ * Only the axes listed in \p axes are sliced; every other axis keeps the
+ * extent of the input.  For the i-th entry of \p axes, the output extent is
+ * the simplified form of `indexdiv(end[i] - begin[i], strides[i])`, and output
+ * index `k` along that axis reads input index `k * strides[i] + begin[i]`.
+ *
+ * The arrays \p begin, \p end, \p strides and \p axes must all have the same
+ * length, which may not exceed the rank of \p x.  Each axis is checked to be
+ * less than the rank of \p x.
+ *
+ * \note begin/end are used exactly as given: nothing in this function clamps
+ * them to the input shape, and axis values are used as-is (no normalization
+ * of negative axes is performed here).
+ * \note NOTE(review): the extent is computed with indexdiv rather than a
+ * ceiling division, so when `end - begin` is not a multiple of the stride the
+ * result may be smaller than a rounded-up count -- confirm this is intended.
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing, one per entry of axes
+ * \param end Indices indicating end of the slice, one per entry of axes
+ * \param strides Specifies the stride values, one per entry of axes.  The
+ * original documentation states a stride can be negative, in which case the
+ * input tensor will be reversed in that particular axis; NOTE(review): there
+ * is no separate handling for negative strides in this function -- verify.
+ * \param axes Specifies which axes of x are sliced; axes[i] is the axis that
+ * begin[i], end[i] and strides[i] apply to.
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the dynamic_strided_slice_with_axes operation
+ */
+inline Tensor dynamic_strided_slice_with_axes(
+    const Tensor& x, const Array<PrimExpr>& begin, const Array<PrimExpr>& end,
+    const Array<PrimExpr>& strides, const Array<Integer>& axes,
+    std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = 
kInjective) {
+  const size_t src_tensor_dim = x->shape.size();
+  ICHECK_EQ(begin.size(), end.size());
+  ICHECK_EQ(begin.size(), strides.size());
+  ICHECK_EQ(begin.size(), axes.size());
+  ICHECK_LE(begin.size(), src_tensor_dim);
+
+  for (const auto& axis_imm : axes) {
+    int axis = axis_imm->value;
+    ICHECK_LT(axis, src_tensor_dim);
+  }
+
+  arith::Analyzer analyzer;
+
+  Array<PrimExpr> out_shape = x->shape;
+  for (size_t i = 0; i < begin.size(); i++) {
+    int axis = axes[i]->value;
+    PrimExpr new_shape = analyzer.Simplify(indexdiv(end[i] - begin[i], 
strides[i]));
+    out_shape.Set(axis, new_shape);
+  }
+
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices = indices.Map([](const auto& var) -> 
PrimExpr { return var; });
+
+        for (size_t i = 0; i < begin.size(); i++) {
+          int axis = axes[i]->value;
+          PrimExpr new_index = indices[axis] * strides[i] + begin[i];
+          real_indices.Set(axis, new_index);
+        }
+
+        return x(real_indices);
+      },
+      name, tag);
+}
+
 /*!
  * \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic
  *
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index 906f2b9115..fc91f4bbf1 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -178,19 +178,23 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
   Array<PrimExpr> begin = args[1];
   Array<PrimExpr> end = args[2];
   Array<PrimExpr> strides = args[3];
+  Array<Integer> axes = args[4];
   if (IsConstIntArray(begin) && IsConstIntArray(end) && 
IsConstIntArray(strides)) {
     Array<Integer> begin_static = args[1];
     Array<Integer> end_static = args[2];
     Array<Integer> strides_static = args[3];
-    Array<Integer> axes = args[4];
     std::string slice_mode = args[5];
-    if (axes.size() > 0) {
+    if (axes.size()) {
       *rv = strided_slice_with_axes(x, begin_static, end_static, 
strides_static, axes, slice_mode);
     } else {
       *rv = strided_slice(x, begin_static, end_static, strides_static, 
slice_mode);
     }
   } else {
-    *rv = dynamic_strided_slice(x, begin, end, strides);
+    if (axes.size()) {
+      *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes);
+    } else {
+      *rv = dynamic_strided_slice(x, begin, end, strides);
+    }
   }
 });
 
diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py
index e31c383ce7..fffd31f6f7 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -964,7 +964,7 @@ def test_legalize_dynamic_begin_end():
             for iters in T.grid(*B.shape):
                 with T.block("T_dynamic_strided_slice"):
                     i, j = T.axis.remap("SS", iters)
-                    B[i, j] = A[T.min(index, T.int64(15)) + i, j]
+                    B[i, j] = A[i + index, j]
 
     after = tvm.relax.transform.LegalizeOps()(before)
     tvm.ir.assert_structural_equal(expected, after)
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index e182c8793d..862f4a66ed 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -19,10 +19,12 @@ import numpy as np
 import pytest
 import tvm
 from tvm import te
+from tvm import tir
 from tvm import topi
 from tvm import relay
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
+from tvm.script import tir as T
 
 import tvm.testing
 
@@ -892,7 +894,49 @@ def test_strided_slice():
     verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
     verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
     verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None])
-    verify_strided_slice((3, 4, 3), [0], [2], None, axes=[1])
+
+
+def test_strided_slice_with_dynamic_bounds():
+    """The begin/end of strided_slice can be a PrimExpr
+
+    Where topi.dynamic_strided_slice uses begin/end values provided at
+    runtime, strided_slice takes begin/end values at compile-time.
+    However, these begin/end values may depend on dynamic variables.
+    Previously, these resulted in dispatch to
+    `tvm::topi::dynamic_strided_slice`, ignoring the `axes` argument.
+
+    The axes are listed as `[2, 1]`, so begin1/end1 apply to axis 2 and
+    begin2/end2 apply to axis 1, while axis 0 is not sliced.  The
+    expected PrimFunc below encodes that mapping in both the shape of
+    `B` and the indices used to read `A`.
+    """
+    A = te.placeholder(shape=[16, 32, 64], name="A")
+    begins = [tir.Var("begin1", "int32"), tir.Var("begin2", "int32")]
+    ends = [tir.Var("end1", "int32"), tir.Var("end2", "int32")]
+    strides = [1, 1]
+    axes = [2, 1]
+
+    # Dummy tensor to provide begin/end variables in PrimFunc scope.
+    # Outside of a test case, these would typically be provided
+    # through another means, or bound to a static value at a later
+    # point.
+    Dummy = te.placeholder(shape=[*begins, *ends], name="Dummy")
+
+    B = topi.strided_slice(A, begins, ends, strides, axes)
+
+    func = te.create_prim_func([A, Dummy, B]).without_attr("global_symbol")
+
+    @T.prim_func(private=True)
+    def expected(
+        A: T.Buffer((16, 32, 64), "float32"),
+        var_Dummy: T.handle,
+        B_handle: T.handle,
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        begin1, begin2, end1, end2 = T.int32(), T.int32(), T.int32(), T.int32()
+        Dummy = T.match_buffer(var_Dummy, (begin1, begin2, end1, end2))
+        B = T.match_buffer(B_handle, (16, end2 - begin2, end1 - begin1))
+        for iters in T.grid(*B.shape):
+            with T.block("T_dynamic_strided_slice_with_axes"):
+                i, j, k = T.axis.remap("SSS", iters)
+                B[i, j, k] = A[i, j + begin2, k + begin1]
+
+    tvm.ir.assert_structural_equal(expected, func)
 
 
 @tvm.testing.uses_gpu

Reply via email to