This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0d5bde59de [FIX][TOPI][strided_slice] Fix topi.strided_slice output
shape (#17502)
0d5bde59de is described below
commit 0d5bde59def02bf700339cda450e774b6d53407d
Author: PatrikPerssonInceptron
<[email protected]>
AuthorDate: Tue Nov 12 08:48:57 2024 +0100
[FIX][TOPI][strided_slice] Fix topi.strided_slice output shape (#17502)
* updated topi.strided_slice to perform the same index canonicalization as
relax.strided_slice given an assume_inbound flag
* applied formatting
* removed debug statements
* set assume_inbound=True in leg_redistribute_replica_to_shard as this is
assumed
in the associated unit tests
moved misplaced function description
fixed shape error and added an assume_inbound=True to
test_strided_slice_symbolic_sliced_axis
added param description for assume_inbound
* added doc for param assume_inbound in dynamic_strided_slice
* added # fmt: off and # fmt: on
---
include/tvm/topi/transform.h | 68 ++++++++++++++++++++--
.../relax/transform/legalize_ops/distributed.py | 1 +
python/tvm/relax/transform/legalize_ops/index.py | 1 +
python/tvm/topi/transform.py | 7 ++-
src/relax/op/tensor/index.cc | 26 +--------
src/topi/transform.cc | 9 ++-
tests/python/relax/test_op_index.py | 40 +++++++++++++
..._transform_legalize_ops_index_linear_algebra.py | 2 +-
8 files changed, 121 insertions(+), 33 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 3292ce57ba..bada827c81 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -43,7 +43,11 @@
#include <unordered_set>
#include <vector>
+#include "tvm/ir/expr.h"
+#include "tvm/runtime/data_type.h"
#include "tvm/tir/expr.h"
+#include "tvm/tir/op.h"
+#include "tvm/tir/var.h"
namespace tvm {
namespace topi {
@@ -635,6 +639,55 @@ inline Array<Tensor> split(const Tensor& x,
Array<PrimExpr> split_indices, int a
return result;
}
+inline PrimExpr DynamicCanonicalizeIndex(PrimExpr index, PrimExpr extent,
PrimExpr stride) {
+ auto idx_var = index.as<tvm::tir::VarNode>();
+ auto extent_var = extent.as<tvm::tir::VarNode>();
+
+ if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) {
+ return index;
+ }
+
+ PrimExpr begin_range = tvm::if_then_else(stride < 0, -1, 0);
+ PrimExpr end_range = tvm::if_then_else(stride < 0, extent - 1, extent);
+
+ if (!(index->IsInstance<tvm::IntImmNode>() && GetConstInt(index) >= 0)) {
+ index = tvm::if_then_else(index < 0, index + extent, index);
+ }
+
+ return tvm::min(tvm::max(index, begin_range), end_range);
+}
+
+inline int64_t StaticCanonicalizeIndex(int64_t index, int64_t extent, int64_t
stride) {
+ int64_t begin_range = stride < 0 ? -1 : 0;
+ int64_t end_range = stride < 0 ? extent - 1 : extent;
+ if (index < 0) {
+ index += extent;
+ }
+ return std::min(std::max(index, begin_range), end_range);
+}
+
+inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr
stride) {
+ if (index->IsInstance<tvm::IntImmNode>() &&
extent->IsInstance<tvm::IntImmNode>() &&
+ stride->IsInstance<tvm::IntImmNode>()) {
+ return tvm::IntImm(
+ tvm::DataType::Int(64),
+ StaticCanonicalizeIndex(GetConstInt(index), GetConstInt(extent),
GetConstInt(stride)));
+ }
+ return DynamicCanonicalizeIndex(index, extent, stride);
+}
+
+inline PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride,
PrimExpr extent,
+ bool assume_inbound = true) {
+ if (assume_inbound) {
+ return ceildiv(end - begin, stride);
+ } else {
+ begin = CanonicalizeIndex(begin, extent, stride);
+ end = CanonicalizeIndex(end, extent, stride);
+ return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
+ ceildiv(end - begin, stride));
+ }
+}
+
/*!
* \brief strided_slice of a tensor where begin/end/stride can be mixed static
and dynamic
*
@@ -644,6 +697,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr>
split_indices, int a
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param axes Specifies which axes will be updated.
+ * \param assume_inbound Specifies if all indices are assumed to be inbound
* \param name The name of the operation
* \param tag The tag to mark the operation
*
@@ -651,7 +705,7 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr>
split_indices, int a
*/
inline Tensor dynamic_strided_slice_with_axes(
const Tensor& x, const Array<PrimExpr>& begin, const Array<PrimExpr>& end,
- const Array<PrimExpr>& strides, const Array<Integer>& axes,
+ const Array<PrimExpr>& strides, const Array<Integer>& axes, bool
assume_inbound = true,
std::string name = "T_dynamic_strided_slice_with_axes", std::string tag =
kInjective) {
const size_t src_tensor_dim = x->shape.size();
ICHECK_EQ(begin.size(), end.size());
@@ -669,7 +723,8 @@ inline Tensor dynamic_strided_slice_with_axes(
Array<PrimExpr> out_shape = x->shape;
for (size_t i = 0; i < begin.size(); i++) {
int axis = axes[i]->value;
- PrimExpr new_shape = analyzer.Simplify(ceildiv(end[i] - begin[i],
strides[i]));
+ PrimExpr new_shape =
+ analyzer.Simplify(GetLength(begin[i], end[i], strides[i],
out_shape[axis], assume_inbound));
out_shape.Set(axis, new_shape);
}
@@ -697,6 +752,7 @@ inline Tensor dynamic_strided_slice_with_axes(
* \param end Indices indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
+ * \param assume_inbound Specifies if all indices are assumed to be inbound
* \param name The name of the operation
* \param tag The tag to mark the operation
*
@@ -704,6 +760,7 @@ inline Tensor dynamic_strided_slice_with_axes(
*/
inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>&
begin,
const Array<PrimExpr>& end, const
Array<PrimExpr>& strides,
+ bool assume_inbound = true,
std::string name =
"T_dynamic_strided_slice",
std::string tag = kInjective) {
const size_t src_tensor_dim = x->shape.size();
@@ -721,7 +778,8 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const
Array<PrimExpr>& begi
// Check ProducerLoad to keep backward compatibility for Relay.
if (!begin[i]->IsInstance<ProducerLoadNode>() &&
!end[i]->IsInstance<ProducerLoadNode>() &&
!strides[i]->IsInstance<ProducerLoadNode>()) {
- out_shape.push_back(analyzer.Simplify(ceildiv(end[i] - begin[i],
strides[i])));
+ out_shape.push_back(
+ analyzer.Simplify(GetLength(begin[i], end[i], strides[i],
x->shape[i], assume_inbound)));
} else {
out_shape.push_back(tvm::tir::Var("dim"));
}
@@ -755,6 +813,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const
Array<PrimExpr>& begi
* \param end Indices indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
+ * \param assume_inbound Specifies if all indices are assumed to be inbound
* \param name The name of the operation
* \param tag The tag to mark the operation
*
@@ -762,6 +821,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const
Array<PrimExpr>& begi
*/
inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor&
begin,
const te::Tensor& end, const
te::Tensor& strides,
+ bool assume_inbound = true,
std::string name =
"T_strided_slice_dynamic",
std::string tag = topi::kInjective) {
DataType index_dtype = begin->shape[0]->dtype;
@@ -776,7 +836,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor&
x, const te::Tensor& b
end_expr.push_back(end(ind));
strides_expr.push_back(strides(ind));
}
- return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name,
tag);
+ return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr,
assume_inbound, name, tag);
}
/*!
diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py
b/python/tvm/relax/transform/legalize_ops/distributed.py
index d540628e0e..2ca283d1ec 100644
--- a/python/tvm/relax/transform/legalize_ops/distributed.py
+++ b/python/tvm/relax/transform/legalize_ops/distributed.py
@@ -40,4 +40,5 @@ def _redistribute_replica_to_shard(_bb: BlockBuilder, call:
Call) -> Expr:
axes=[axis],
begin=[worker_id_symbol * split_axis_size // num_workers],
end=[(worker_id_symbol + 1) * split_axis_size // num_workers],
+ assume_inbound=True,
)
diff --git a/python/tvm/relax/transform/legalize_ops/index.py
b/python/tvm/relax/transform/legalize_ops/index.py
index a4fac46a13..8d0ac535f6 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -67,6 +67,7 @@ def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
strides,
axes,
slice_mode="end",
+ assume_inbound=call.attrs.assume_inbound,
)
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index 3b007a6325..686311fbee 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1,
batch_axis=0):
return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)
-def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
+def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end",
assume_inbound=True):
"""Slice of an array.
Parameters
@@ -200,6 +200,9 @@ def strided_slice(a, begin, end, strides=None, axes=None,
slice_mode="end"):
the sizeof a slice starting at the location specified by begin. If
end[i]
is -1, all remaining elements in that dimension are included in the
slice.
+ assume_inbound: bool, optional
+ A flag to indicate if all indices are assumed to be inbound
+
Returns
-------
ret : tvm.te.Tensor
@@ -223,7 +226,7 @@ def strided_slice(a, begin, end, strides=None, axes=None,
slice_mode="end"):
strides = []
if axes is None:
axes = []
- return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)
+ return cpp.strided_slice(a, begin, end, strides, axes, slice_mode,
assume_inbound)
def dynamic_strided_slice(a, begin, end, strides, output_shape):
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 36527c3584..e62dbe89d0 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -25,6 +25,7 @@
#include "index.h"
#include <tvm/relax/analysis.h>
+#include <tvm/topi/transform.h>
#include <algorithm>
#include <optional>
@@ -171,29 +172,6 @@ Expr strided_slice(Expr x, Expr axes, Expr begin, Expr
end, Optional<Expr> strid
TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice);
-inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, PrimExpr
stride) {
- // Handle Python-style negative indices
- index = if_then_else(index < 0, index + extent, index);
- // Clamp the result to valid indices
- PrimExpr lower_bound = tvm::if_then_else(stride < 0, -1, 0);
- PrimExpr upper_bound = tvm::if_then_else(stride < 0, extent - 1, extent);
- index = tvm::min(tvm::max(index, lower_bound), upper_bound);
-
- return index;
-}
-
-PrimExpr GetLength(PrimExpr begin, PrimExpr end, PrimExpr stride, PrimExpr
extent,
- bool assume_inbound) {
- if (assume_inbound) {
- return ceildiv(end - begin, stride);
- } else {
- begin = CanonicalizeIndex(begin, extent, stride);
- end = CanonicalizeIndex(end, extent, stride);
- return tvm::if_then_else(stride < 0, ceildiv(begin - end, -stride),
- ceildiv(end - begin, stride));
- }
-}
-
/* \brief Helper function to unpack a relax::Tuple
*
* A `relax::Tuple` may be provided to an operator as an in-line
@@ -424,7 +402,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call,
const BlockBuilder& ctx
PrimExpr end = end_tuple[i];
PrimExpr output_dim =
- GetLength(begin, end, strides_tuple[i], input_dim,
attrs->assume_inbound);
+ topi::GetLength(begin, end, strides_tuple[i], input_dim,
attrs->assume_inbound);
arith::Analyzer* analyzer = ctx->GetAnalyzer();
std::optional<With<arith::ConstraintContext>> context;
diff --git a/src/topi/transform.cc b/src/topi/transform.cc
index a84e3dce50..d844739568 100644
--- a/src/topi/transform.cc
+++ b/src/topi/transform.cc
@@ -27,6 +27,10 @@
#include <tvm/topi/transform.h>
#include <tvm/topi/utils.h>
+#include <iostream>
+
+#include "tvm/ir/expr.h"
+
namespace tvm {
namespace topi {
@@ -179,6 +183,7 @@
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
Array<Integer> axes = args[4];
+ bool assume_inbound = args[6];
if (IsConstIntArray(begin) && IsConstIntArray(end) &&
IsConstIntArray(strides) &&
IsConstIntArray(x->shape)) {
Array<Integer> begin_static = args[1];
@@ -192,9 +197,9 @@
TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
}
} else {
if (axes.size()) {
- *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes);
+ *rv = dynamic_strided_slice_with_axes(x, begin, end, strides, axes,
assume_inbound);
} else {
- *rv = dynamic_strided_slice(x, begin, end, strides);
+ *rv = dynamic_strided_slice(x, begin, end, strides, assume_inbound);
}
}
});
diff --git a/tests/python/relax/test_op_index.py
b/tests/python/relax/test_op_index.py
index 57e7a14b70..31245de599 100644
--- a/tests/python/relax/test_op_index.py
+++ b/tests/python/relax/test_op_index.py
@@ -21,6 +21,7 @@ from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op, VDevice
from tvm.script import ir as I, relax as R, tir as T
+import numpy as np
def test_op_correctness():
@@ -1010,5 +1011,44 @@ def test_legalize_dynamic_begin_end():
tvm.ir.assert_structural_equal(expected, after)
+def test_legalize_dynamic_begin_inf_end():
+ """relax.op.strided_slice FLegalize must support dynamic begin/end"""
+
+ @I.ir_module
+ class before:
+ @R.function
+ def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) ->
R.Tensor((1, 16)):
+ index = T.int64()
+ return R.strided_slice(
+ A, [0], [index], [T.int64(np.iinfo(np.int64).max)],
assume_inbound=False
+ )
+
+ # fmt: off
+ @I.ir_module
+ class expected:
+ @T.prim_func(private=True)
+ def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"),
var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ T_dynamic_strided_slice_with_axes =
T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) -
T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index),
T.int64(0)), T.int64(0)), T.int64(16)))
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.max(T.int64(16) -
T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index),
T.int64(0)), T.int64(0)), T.int64(16)):
+ with T.block("T_dynamic_strided_slice_with_axes"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0 + index, v_ax1])
+ T.writes(T_dynamic_strided_slice_with_axes[v_ax0, v_ax1])
+ T_dynamic_strided_slice_with_axes[v_ax0, v_ax1] = A[v_ax0
+ index, v_ax1]
+
+ @R.function
+ def main(A: R.Tensor((16, 16), dtype="float32"), B:
R.Shape(["index"])) -> R.Tensor(("T.max(16 - T.max(T.if_then_else(index < 0,
index + 16, index), 0), 0)", 16), dtype="float32"):
+ index = T.int64()
+ cls = expected
+ gv = R.call_tir(cls.strided_slice, (A,),
out_sinfo=R.Tensor((T.max(16 - T.max(T.if_then_else(index < 0, index + 16,
index), 0), 0), 16), dtype="float32"), tir_vars=R.shape([index]))
+ return gv
+ # fmt: on
+
+ after = tvm.relax.transform.LegalizeOps()(before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 2f4da5cf06..90643694c1 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -263,7 +263,7 @@ def test_strided_slice_symbolic_sliced_axis():
@R.function
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"),
"float32"):
n = T.int64()
- gv: R.Tensor((2, n), "float32") = R.strided_slice(x, axes=[0],
begin=[1], end=[8], strides=[3])
+ gv: R.Tensor((3, n), "float32") = R.strided_slice(x, axes=[0],
begin=[1], end=[8], strides=[3], assume_inbound=True)
return gv
@I.ir_module