This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e100a13737 [Unity] Fix legalizing strided slice (#16232)
e100a13737 is described below

commit e100a13737beae79d4f740945192e32f021236e9
Author: Wuwei Lin <[email protected]>
AuthorDate: Thu Dec 14 07:11:37 2023 -0800

    [Unity] Fix legalizing strided slice (#16232)
    
    * [Unity] Fix legalizing strided slice
    
    Strided slice with symbolic bounds is supported in TOPI, so this PR removes
    the check for dynamic shapes in legalization. It also fixes the calculation
    of the output shape when stride != 1 (ceildiv instead of floor division).
---
 include/tvm/topi/transform.h                       |  4 +-
 python/tvm/relax/transform/legalize_ops/index.py   | 12 ---
 src/topi/transform.cc                              |  3 +-
 ..._transform_legalize_ops_index_linear_algebra.py | 90 +++++++++++++++++++++-
 4 files changed, 93 insertions(+), 16 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 68f5d36c82..a1f66a70ca 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -669,7 +669,7 @@ inline Tensor dynamic_strided_slice_with_axes(
   Array<PrimExpr> out_shape = x->shape;
   for (size_t i = 0; i < begin.size(); i++) {
     int axis = axes[i]->value;
-    PrimExpr new_shape = analyzer.Simplify(indexdiv(end[i] - begin[i], 
strides[i]));
+    PrimExpr new_shape = analyzer.Simplify(ceildiv(end[i] - begin[i], 
strides[i]));
     out_shape.Set(axis, new_shape);
   }
 
@@ -721,7 +721,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const 
Array<PrimExpr>& begi
     // Check ProducerLoad to keep backward compatibility for Relay.
     if (!begin[i]->IsInstance<ProducerLoadNode>() && 
!end[i]->IsInstance<ProducerLoadNode>() &&
         !strides[i]->IsInstance<ProducerLoadNode>()) {
-      out_shape.push_back(analyzer.Simplify(indexdiv(end[i] - begin[i], 
strides[i])));
+      out_shape.push_back(analyzer.Simplify(ceildiv(end[i] - begin[i], 
strides[i])));
     } else {
       out_shape.push_back(tvm::tir::Var("dim"));
     }
diff --git a/python/tvm/relax/transform/legalize_ops/index.py 
b/python/tvm/relax/transform/legalize_ops/index.py
index 13228c4805..5889da9487 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -16,8 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name
 """Default legalization function for index operators."""
-import logging
-
 from tvm import topi, tir, te
 from ...op import call_pure_packed
 from ...block_builder import BlockBuilder
@@ -37,16 +35,6 @@ def _take(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.strided_slice")
 def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
-    if not all(
-        isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm)
-        for i in call.attrs.axes
-    ):
-        logging.info(
-            "Cases where an axis with symbolic length is sliced are not able "
-            "to be legalized through TOPI"
-        )
-        return call
-
     strides = (
         [tir.IntImm("int64", 1)] * len(call.attrs.axes)
         if call.attrs.strides is None
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index fc91f4bbf1..a84e3dce50 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -179,7 +179,8 @@ 
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
   Array<PrimExpr> end = args[2];
   Array<PrimExpr> strides = args[3];
   Array<Integer> axes = args[4];
-  if (IsConstIntArray(begin) && IsConstIntArray(end) && 
IsConstIntArray(strides)) {
+  if (IsConstIntArray(begin) && IsConstIntArray(end) && 
IsConstIntArray(strides) &&
+      IsConstIntArray(x->shape)) {
     Array<Integer> begin_static = args[1];
     Array<Integer> end_static = args[2];
     Array<Integer> strides_static = args[3];
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 0c84daa572..0d1e969b35 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -168,10 +168,34 @@ def test_strided_slice_symbolic_sliced_axis():
             n = T.int64()
             gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], 
begin=[1], end=[8], strides=[3])
             return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def strided_slice(var_A: T.handle, 
var_T_dynamic_strided_slice_with_axes: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            m, n = T.int64(), T.int64()
+            A = T.match_buffer(var_A, (m, n))
+            T_dynamic_strided_slice_with_axes = 
T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.int64(3), n))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(3), n):
+                with T.block("T_dynamic_strided_slice_with_axes"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0 * T.int64(3) + T.int64(1), v_ax1])
+                    T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1])
+                    T_dynamic_strided_slice_with_axes[v_ax0, v_ax1] = A[v_ax0 
* T.int64(3) + T.int64(1), v_ax1]
+
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor((3, 
"n"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            cls = Expected
+            gv = R.call_tir(cls.strided_slice, (x,), out_sinfo=R.Tensor((3, 
n), dtype="float32"))
+            return gv
     # fmt: on
 
     mod = LegalizeOps()(StridedSlice)
-    tvm.ir.assert_structural_equal(mod, StridedSlice)
+    tvm.ir.assert_structural_equal(mod, Expected)
 
 
 def test_strided_slice_symbolic():
@@ -210,6 +234,70 @@ def test_strided_slice_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_strided_slice_symbolic_bound():
+    # fmt: off
+    @tvm.script.ir_module
+    class StridedSlice:
+        @R.function
+        def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), 
"float32"):
+            n = T.int64(is_size_var=True)
+            gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0, 1], 
begin=[1, 0], end=[8, n], strides=[3, 1])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, 
"n"), dtype="float32"):
+            n = T.int64(is_size_var=True)
+            gv = R.call_tir(Expected.strided_slice, (x,), R.Tensor((3, n), 
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def strided_slice(var_rxplaceholder: T.handle, 
var_T_strided_slice_with_axes: T.handle):
+            T.func_attr({"tir.noalias": True})
+            n = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), 
n], dtype="float32")
+            T_strided_slice_with_axes = 
T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32")
+            for i0, i1 in T.grid(T.int64(3), n):
+                with T.block("T_strided_slice_with_axes"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1])
+                    T.writes(T_strided_slice_with_axes[ax0, ax1])
+                    T_strided_slice_with_axes[ax0, ax1] = rxplaceholder[ax0 * 
T.int64(3) + T.int64(1), ax1]
+
+
+def test_strided_slice_non_unit_stride():
+    # fmt: off
+    @tvm.script.ir_module
+    class StridedSlice:
+        @R.function
+        def main(x: R.Tensor((10, "n"), "float32")) -> R.Tensor((3, "n"), 
"float32"):
+            n = T.int64(is_size_var=True)
+            gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0, 1], 
begin=[1, 0], end=[8, n], strides=[3, 1])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, 
"n"), dtype="float32"):
+            n = T.int64(is_size_var=True)
+            gv = R.call_tir(Expected.strided_slice, (x,), R.Tensor((3, n), 
dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def strided_slice(var_rxplaceholder: T.handle, 
var_T_strided_slice_with_axes: T.handle):
+            T.func_attr({"tir.noalias": True})
+            n = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), 
n], dtype="float32")
+            T_strided_slice_with_axes = 
T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32")
+            for i0, i1 in T.grid(T.int64(3), n):
+                with T.block("T_strided_slice_with_axes"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[ax0 * T.int64(3) + T.int64(1), ax1])
+                    T.writes(T_strided_slice_with_axes[ax0, ax1])
+                    T_strided_slice_with_axes[ax0, ax1] = rxplaceholder[ax0 * 
T.int64(3) + T.int64(1), ax1]
+
+
 def test_dynamic_strided_slice():
     # fmt: off
     @tvm.script.ir_module

Reply via email to