This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d5bde59de [FIX][TOPI][strided_slice] Fix topi.strided_slice output 
shape (#17502)
0d5bde59de is described below

commit 0d5bde59def02bf700339cda450e774b6d53407d
Author: PatrikPerssonInceptron 
<[email protected]>
AuthorDate: Tue Nov 12 08:48:57 2024 +0100

    [FIX][TOPI][strided_slice] Fix topi.strided_slice output shape (#17502)
    
    * updated topi.strided_slice to perform the same index canonicalization as
    relax.strided_slice, given an assume_inbound flag
    
    * applied formatting
    
    * removed debug statements
    
    * set assume_inbound=True in leg_redistribute_replica_to_shard as this is 
assumed
    in the associated unit tests
    
    moved misplaced function description
    
    fixed shape error and added an assume_inbound=True to
    test_strided_slice_symbolic_sliced_axis
    
    added param description for assume_inbound
    
    * added doc for param assume_inbound in dynamic_strided_slice
    
    * added # fmt: off and # fmt: on
---
 include/tvm/topi/transform.h                       | 68 ++++++++++++++++++++--
 .../relax/transform/legalize_ops/distributed.py    |  1 +
 python/tvm/relax/transform/legalize_ops/index.py   |  1 +
 python/tvm/topi/transform.py                       |  7 ++-
 src/relax/op/tensor/index.cc                       | 26 +--------
 src/topi/transform.cc                              |  9 ++-
 tests/python/relax/test_op_index.py                | 40 +++++++++++++
 ..._transform_legalize_ops_index_linear_algebra.py |  2 +-
 8 files changed, 121 insertions(+), 33 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 3292ce57ba..bada827c81 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -43,7 +43,11 @@
 #include <unordered_set>
 #include <vector>
 
+#include "tvm/ir/expr.h"
+#include "tvm/runtime/data_type.h"
 #include "tvm/tir/expr.h"
+#include "tvm/tir/op.h"
+#include "tvm/tir/var.h"
 
 namespace tvm {
 namespace topi {
@@ -635,6 +639,55 @@ inline Array<Tensor> split(const Tensor& x, 
Array<PrimExpr> split_indices, int a
   return result;
 }
 
+inline PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent, 
PrimExpr stride) {
+  auto idx_var = index.as<tvm::tir::VarNode>();
+  auto extent_var = extent.as<tvm::tir::VarNode>();
+
+  if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) {
+    return index;
+  }
+
+  PrimExpr begin_range = tvm::if_then_else(stride < 0, -1, 0);
+  PrimExpr end_range = tvm::if_then_else(stride < 0, extent - 1, extent);
+
+  if (!(index->IsInstance<tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
+    index = tvm::if_then_else(index < 0, index + extent, index);
+  }
+
+  return tvm::min(tvm::max(index, begin_range), end_range);
+}
+
+inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t 
stride) {
+  int64_t begin_range = stride < 0 ? -1 : 0;
+  int64_t end_range = stride < 0 ? extent - 1 : extent;
+  if (index < 0) {
+    index += extent;
+  }
+  return std::min(std::max(index, begin_range), end_range);
+}
+
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr 
stride) {
+  if (index->IsInstance<tvm::IntImmNode>() && 
extent->IsInstance<tvm::IntImmNode>() &&
+      stride->IsInstance<tvm::IntImmNode>()) {
+    return tvm::IntImm(
+        tvm::DataType::Int(64),
+        StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent), 
GetConstInt(stride)));
+  }
+  return DynamicCanonicalizeIndex(index, extent, stride);
+}
+
+inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, 
PrimExpr extent,
+                          bool assume_inbound = true) {
+  if (assume_inbound) {
+    return ceildiv(end - begin, stride);
+  } else {
+    begin = CanonicalizeIndex(begin, extent, stride);
+    end = CanonicalizeIndex(end, extent, stride);
+    return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
+                             ceildiv(end - begin, stride));
+  }
+}
+
 /*!
  * \brief strided_slice of a tensor where begin/end/stride can be mixed static 
and dynamic
  *
@@ -644,6 +697,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> 
split_indices, int a
  * \param strides Specifies the stride values, it can be negative
  * in that case, the input tensor will be reversed in that particular axis
  * \param axes Specifies which axes will be updated.
+ * \param assume_inbound Specifies if all indices are assumed to be inbound
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
@@ -651,7 +705,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> 
split_indices, int a
  */
 inline Tensor dynamic_strided_slice_with_axes(
     const Tensor& x, const Array<PrimExpr>& begin, const Array<PrimExpr>& end,
-    const Array<PrimExpr>& strides, const Array<Integer>& axes,
+    const Array<PrimExpr>& strides, const Array<Integer>& axes, bool 
assume_inbound = true,
     std::string name = "T_dynamic_strided_slice_with_axes", std::string tag = 
kInjective) {
   const size_t src_tensor_dim = x->shape.size();
   ICHECK_EQ(begin.size(), end.size());
@@ -669,7 +723,8 @@ inline Tensor dynamic_strided_slice_with_axes(
   Array<PrimExpr> out_shape = x->shape;
   for (size_t i = 0; i < begin.size(); i++) {
     int axis = axes[i]->value;
-    PrimExpr new_shape = analyzer.Simplify(ceildiv(end[i] - begin[i], 
strides[i]));
+    PrimExpr new_shape =
+        analyzer.Simplify(GetLength(begin[i], end[i], strides[i], 
out_shape[axis], assume_inbound));
     out_shape.Set(axis, new_shape);
   }
 
@@ -697,6 +752,7 @@ inline Tensor dynamic_strided_slice_with_axes(
  * \param end Indices indicating end of the slice
  * \param strides Specifies the stride values, it can be negative
  * in that case, the input tensor will be reversed in that particular axis
+ * \param assume_inbound Specifies if all indices are assumed to be inbound
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
@@ -704,6 +760,7 @@ inline Tensor dynamic_strided_slice_with_axes(
  */
 inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& 
begin,
                                     const Array<PrimExpr>& end, const 
Array<PrimExpr>& strides,
+                                    bool assume_inbound = true,
                                     std::string name = 
"T_dynamic_strided_slice",
                                     std::string tag = kInjective) {
   const size_t src_tensor_dim = x->shape.size();
@@ -721,7 +778,8 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const 
Array<PrimExpr>& begi
     // Check ProducerLoad to keep backward compatibility for Relay.
     if (!begin[i]->IsInstance<ProducerLoadNode>() && 
!end[i]->IsInstance<ProducerLoadNode>() &&
         !strides[i]->IsInstance<ProducerLoadNode>()) {
-      out_shape.push_back(analyzer.Simplify(ceildiv(end[i] - begin[i], 
strides[i])));
+      out_shape.push_back(
+          analyzer.Simplify(GetLength(begin[i], end[i], strides[i], 
x->shape[i], assume_inbound)));
     } else {
       out_shape.push_back(tvm::tir::Var("dim"));
     }
@@ -755,6 +813,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const 
Array<PrimExpr>& begi
  * \param end Indices indicating end of the slice
  * \param strides Specifies the stride values, it can be negative
  * in that case, the input tensor will be reversed in that particular axis
+ * \param assume_inbound Specifies if all indices are assumed to be inbound
  * \param name The name of the operation
  * \param tag The tag to mark the operation
  *
@@ -762,6 +821,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const 
Array<PrimExpr>& begi
  */
 inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& 
begin,
                                         const te::Tensor& end, const 
te::Tensor& strides,
+                                        bool assume_inbound = true,
                                         std::string name = 
"T_strided_slice_dynamic",
                                         std::string tag = topi::kInjective) {
   DataType index_dtype = begin->shape[0]->dtype;
@@ -776,7 +836,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& 
x, const te::Tensor& b
     end_expr.push_back(end(ind));
     strides_expr.push_back(strides(ind));
   }
-  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, 
tag);
+  return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, 
assume_inbound, name, tag);
 }
 
 /*!
diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py 
b/python/tvm/relax/transform/legalize_ops/distributed.py
index d540628e0e..2ca283d1ec 100644
--- a/python/tvm/relax/transform/legalize_ops/distributed.py
+++ b/python/tvm/relax/transform/legalize_ops/distributed.py
@@ -40,4 +40,5 @@ def _redistribute_replica_to_shard(_bb: BlockBuilder, call: 
Call) -> Expr:
         axes=[axis],
         begin=[worker_id_symbol * split_axis_size // num_workers],
         end=[(worker_id_symbol + 1) * split_axis_size // num_workers],
+        assume_inbound=True,
     )
diff --git a/python/tvm/relax/transform/legalize_ops/index.py 
b/python/tvm/relax/transform/legalize_ops/index.py
index a4fac46a13..8d0ac535f6 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -67,6 +67,7 @@ def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
         strides,
         axes,
         slice_mode="end",
+        assume_inbound=call.attrs.assume_inbound,
     )
 
 
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 3b007a6325..686311fbee 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, 
batch_axis=0):
     return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)
 
 
-def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
+def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end", 
assume_inbound=True):
     """Slice of an array.
 
     Parameters
@@ -200,6 +200,9 @@ def strided_slice(a, begin, end, strides=None, axes=None, 
slice_mode="end"):
         the sizeof a slice starting at the location specified by begin. If 
end[i]
         is -1, all remaining elements in that dimension are included in the 
slice.
 
+    assume_inbound: bool, optional
+        A flag to indicate if all indices are assumed to be inbound
+
     Returns
     -------
     ret : tvm.te.Tensor
@@ -223,7 +226,7 @@ def strided_slice(a, begin, end, strides=None, axes=None, 
slice_mode="end"):
         strides = []
     if axes is None:
         axes = []
-    return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)
+    return cpp.strided_slice(a, begin, end, strides, axes, slice_mode, 
assume_inbound)
 
 
 def dynamic_strided_slice(a, begin, end, strides, output_shape):
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 36527c3584..e62dbe89d0 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -25,6 +25,7 @@
 #include "index.h"
 
 #include <tvm/relax/analysis.h>
+#include <tvm/topi/transform.h>
 
 #include <algorithm>
 #include <optional>
@@ -171,29 +172,6 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr 
end, Optional<Expr> strid
 
 TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
 
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr 
stride) {
-  // Handle Python-style negative indices
-  index = if_then_else(index < 0, index + extent, index);
-  // Clamp the result to valid indices
-  PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
-  PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
-  index = tvm::min(tvm::max(index, lower_bound), upper_bound);
-
-  return index;
-}
-
-PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr 
extent,
-                   bool assume_inbound) {
-  if (assume_inbound) {
-    return ceildiv(end - begin, stride);
-  } else {
-    begin = CanonicalizeIndex(begin, extent, stride);
-    end = CanonicalizeIndex(end, extent, stride);
-    return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
-                             ceildiv(end - begin, stride));
-  }
-}
-
 /* \brief Helper function to unpack a relax::Tuple
  *
  * A `relax::Tuple` may be provided to an operator as an in-line
@@ -424,7 +402,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, 
const BlockBuilder& ctx
       PrimExpr end = end_tuple[i];
 
       PrimExpr output_dim =
-          GetLength(begin, end, strides_tuple[i], input_dim, 
attrs->assume_inbound);
+          topi::GetLength(begin, end, strides_tuple[i], input_dim, 
attrs->assume_inbound);
 
       arith::Analyzer* analyzer = ctx->GetAnalyzer();
       std::optional<With<arith::ConstraintContext>> context;
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index a84e3dce50..d844739568 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -27,6 +27,10 @@
 #include <tvm/topi/transform.h>
 #include <tvm/topi/utils.h>
 
+#include <iostream>
+
+#include "tvm/ir/expr.h"
+
 namespace tvm {
 namespace topi {
 
@@ -179,6 +183,7 @@ 
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
   Array<PrimExpr> end = args[2];
   Array<PrimExpr> strides = args[3];
   Array<Integer> axes = args[4];
+  bool assume_inbound = args[6];
   if (IsConstIntArray(begin) && IsConstIntArray(end) && 
IsConstIntArray(strides) &&
       IsConstIntArray(x->shape)) {
     Array<Integer> begin_static = args[1];
@@ -192,9 +197,9 @@ 
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
     }
   } else {
     if (axes.size()) {
-      *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes);
+      *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes, 
assume_inbound);
     } else {
-      *rv = dynamic_strided_slice(x, begin, end, strides);
+      *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound);
     }
   }
 });
diff --git a/tests/python/relax/test_op_index.py 
b/tests/python/relax/test_op_index.py
index 57e7a14b70..31245de599 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -21,6 +21,7 @@ from tvm import relax, tir
 from tvm import TVMError
 from tvm.ir import Op, VDevice
 from tvm.script import ir as I, relax as R, tir as T
+import numpy as np
 
 
 def test_op_correctness():
@@ -1010,5 +1011,44 @@ def test_legalize_dynamic_begin_end():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_legalize_dynamic_begin_inf_end():
+    """relax.op.strided_slice FLegalize must support dynamic begin/end"""
+
+    @I.ir_module
+    class before:
+        @R.function
+        def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> 
R.Tensor((1, 16)):
+            index = T.int64()
+            return R.strided_slice(
+                A, [0], [index], [T.int64(np.iinfo(np.int64).max)], 
assume_inbound=False
+            )
+
+    # fmt: off
+    @I.ir_module
+    class expected:
+        @T.prim_func(private=True)
+        def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), 
var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            T_dynamic_strided_slice_with_axes = 
T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) - 
T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), 
T.int64(0)), T.int64(0)), T.int64(16)))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.max(T.int64(16) - 
T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), 
T.int64(0)), T.int64(0)), T.int64(16)):
+                with T.block("T_dynamic_strided_slice_with_axes"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0 + index, v_ax1])
+                    T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1])
+                    T_dynamic_strided_slice_with_axes[v_ax0, v_ax1] = A[v_ax0 
+ index, v_ax1]
+
+        @R.function
+        def main(A: R.Tensor((16, 16), dtype="float32"), B: 
R.Shape(["index"])) -> R.Tensor(("T.max(16 - T.max(T.if_then_else(index < 0, 
index + 16, index), 0), 0)", 16), dtype="float32"):
+            index = T.int64()
+            cls = expected
+            gv = R.call_tir(cls.strided_slice, (A,), 
out_sinfo=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16, 
index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index]))
+            return gv
+    # fmt: on
+
+    after = tvm.relax.transform.LegalizeOps()(before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py 
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 2f4da5cf06..90643694c1 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -263,7 +263,7 @@ def test_strided_slice_symbolic_sliced_axis():
         @R.function
         def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), 
"float32"):
             n = T.int64()
-            gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0], 
begin=[1], end=[8], strides=[3])
+            gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0], 
begin=[1], end=[8], strides=[3], assume_inbound=True)
             return gv
 
     @I.ir_module

Reply via email to