This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 79368cea1a [Relax][ONNX][Transform] Add mode choice, new mode, and
warning for take() (#18061)
79368cea1a is described below
commit 79368cea1afab6268b2bbbff96ea3d63a524130e
Author: Youngsik Yang <[email protected]>
AuthorDate: Tue Jul 8 20:58:35 2025 +0900
[Relax][ONNX][Transform] Add mode choice, new mode, and warning for take()
(#18061)
[Relax][Transform] Add mode choice, `nan` mode, and warning for take()
- Add a `mode` parameter to Relax's `take()`
- Add a `nan` mode to `take()` that produces NaN for out-of-bounds indices
- Add unit tests covering the `nan`, `wrap`, and `clip` modes with out-of-bounds indices
- Add a warning log for `fast` mode
- Change the default mode of the lower-level TOPI `take()` (C++ and Python) from `clip` to `fast` for consistency with Relax
---
include/tvm/relax/attrs/index.h | 7 ++-
include/tvm/topi/transform.h | 39 +++++++++++-
python/tvm/relax/op/index.py | 10 ++-
python/tvm/relax/transform/legalize_ops/index.py | 8 +--
python/tvm/topi/transform.py | 11 ++--
src/relax/op/tensor/index.cc | 9 ++-
src/relax/op/tensor/index.h | 3 +-
tests/python/relax/test_op_take.py | 80 ++++++++++++++++++++++++
8 files changed, 147 insertions(+), 20 deletions(-)
diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index 6f283f7044..cc914449db 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -32,11 +32,14 @@ namespace relax {
/*! \brief Attributes used in take operator */
struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
Optional<int64_t> axis;
+ String mode;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<TakeAttrs>().def_ro("axis", &TakeAttrs::axis,
- "The axis over which to select
values.");
+ refl::ObjectDef<TakeAttrs>()
+ .def_ro("axis", &TakeAttrs::axis, "The axis over which to select
values.")
+ .def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds
indices.",
+ refl::DefaultValue("fast"));
}
static constexpr const char* _type_key = "relax.attrs.TakeAttrs";
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index ac32ec6a42..df637f6f58 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1008,7 +1008,7 @@ inline Array<Tensor> split_n_sections(const Tensor& x,
int num_sections, int axi
* \return A Tensor whose op member is the take operation
*/
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
- std::string mode = "clip", std::string name = "T_take",
+ std::string mode = "fast", std::string name = "T_take",
std::string tag = kInjective) {
Array<PrimExpr> a_shape = a->shape;
Array<PrimExpr> out_shape = indices->shape;
@@ -1032,6 +1032,16 @@ inline Tensor take(const Tensor& a, const Tensor&
indices, int batch_dims,
out_shape,
[&](const Array<Var>& out_index) { return
a(UnravelIndex(indices(out_index), a_shape)); },
name, tag);
+  } else if (mode == "nan") {
+    return compute(
+        out_shape,
+        [&](const Array<Var>& out_index) {
+          // Out-of-bounds (including negative) indices yield a NaN *value*. The read of
+          // `a` sits in the guarded branch of if_then_else, so it only happens in range.
+          // NOTE(review): NaN is not representable for integer dtypes; consider
+          // rejecting non-float inputs for this mode.
+          PrimExpr idx = indices(out_index);
+          PrimExpr in_bounds = idx >= 0 && idx < a_size;
+          return tvm::if_then_else(
+              in_bounds, a(UnravelIndex(idx, a_shape)),
+              tvm::tir::make_const(a->dtype, std::numeric_limits<float>::quiet_NaN()));
+        },
+        name, tag);
} else { // mode == "wrap"
return compute(
out_shape,
@@ -1094,7 +1104,7 @@ inline Tensor sequence_mask(const Tensor& data, const
Tensor& valid_length, doub
* \return A Tensor whose op member is the take operation
*/
inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int
batch_dims, int axis,
- std::string mode = "clip", std::string name = "T_take",
+ std::string mode = "fast", std::string name = "T_take",
std::string tag = kInjective) {
if (axis < 0) {
axis += static_cast<int>(a->shape.size());
@@ -1206,6 +1216,8 @@ inline Tensor take(const Tensor& a, Variant<Tensor,
PrimExpr> indices, int batch
name, tag);
}
} else if (mode == "fast") {
+ LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices.
"
+ "Make sure input indices are in bound";
return compute(
out_shape,
[&](const Array<Var>& out_index) {
@@ -1224,6 +1236,29 @@ inline Tensor take(const Tensor& a, Variant<Tensor,
PrimExpr> indices, int batch
return a(real_indices);
},
name, tag);
+ } else if (mode == "nan") {
+ return compute(
+ out_shape,
+ [&](const Array<Var>& out_index) {
+ Array<PrimExpr> indices_position;
+ for (size_t j = axis; j < static_cast<size_t>(axis + indices_len);
++j) {
+ indices_position.push_back(out_index[j]);
+ }
+ Array<PrimExpr> real_indices;
+ for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+ real_indices.push_back(out_index[j]);
+ }
+ PrimExpr idx = get_index(indices_position);
+ real_indices.push_back(idx);
+ for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+ real_indices.push_back(out_index[j]);
+ }
+ PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
+ return tvm::if_then_else(
+ in_bounds, a(real_indices),
+ tvm::tir::make_const(a->dtype,
std::numeric_limits<float>::quiet_NaN()));
+ },
+ name, tag);
} else { // mode == "wrap"
return compute(
out_shape,
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index ec68bd585c..d4283259a8 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -26,7 +26,7 @@ from .. import args_converter
PrimExprLike = Union[int, PrimExpr]
-def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
+def take(x: Expr, indices: Expr, axis: Optional[int] = None, mode: str =
"fast") -> Expr:
"""Take elements from a tensor along an axis.
Its semantic is mostly similar to `numpy.take`
(https://numpy.org/doc/stable/reference/generated/numpy.take.html),
@@ -45,12 +45,18 @@ def take(x: Expr, indices: Expr, axis: Optional[int] =
None) -> Expr:
The axis over which to select values.
If it is none, the input tensor is required to be one-dimensional.
+ mode : str
+ Specifies how out-of-bounds indices will behave.
+ - fast (default): indices are not bounds-checked, so an out-of-bounds index
+ is undefined behavior (e.g. a segfault); the caller must make sure all
+ indices are in bounds
+ - nan: produce NaNs for out-of-bounds indices (NaN is only representable
+ for floating-point dtypes)
+ - wrap: wrap around the indices
+ - clip: clip to the range
Returns
-------
ret : relax.Expr
The taken result.
"""
- return _ffi_api.take(x, indices, axis) # type: ignore
+ return _ffi_api.take(x, indices, axis, mode) # type: ignore
@args_converter.auto
diff --git a/python/tvm/relax/transform/legalize_ops/index.py
b/python/tvm/relax/transform/legalize_ops/index.py
index 8d0ac535f6..d99c1f4db6 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -26,11 +26,9 @@ from .common import register_legalize
@register_legalize("relax.take")
def _take(bb: BlockBuilder, call: Call) -> Expr:
- # Currently Relax `take` operator doesn't provide the mode choices and
- # requires input indices to be in range.
- # We use fast mode, which leads to runtime error whenever some index is
- # out of bound.
- return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis,
mode="fast")
+ # Forward the mode stored on the Relax call ("fast", "nan", "wrap" or "clip")
+ # to TOPI. "fast" is the default and does no bounds handling, so out-of-bounds
+ # indices are undefined behavior (e.g. a segmentation fault).
+ return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis,
mode=call.attrs.mode)
@register_legalize("relax.strided_slice")
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index bcb3ff95fa..98cec99a09 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -446,7 +446,7 @@ def split(ary, indices_or_sections, axis=0):
return cpp.split(ary, indices_or_sections, axis)
-def take(a, indices, axis=None, batch_dims=0, mode="clip"):
+def take(a, indices, axis=None, batch_dims=0, mode="fast"):
"""Take elements from an array along an axis.
Parameters
@@ -465,10 +465,11 @@ def take(a, indices, axis=None, batch_dims=0,
mode="clip"):
The number of batch dimensions. By default is 0.
mode : str, optional
- Specifies how out-of-bound indices will behave.
- clip - clip to the range (default)
- wrap - wrap around the indices
- fast - no clip or wrap around (user must make sure indices are
in-bound)
+ Specifies how out-of-bounds indices will behave.
+ - fast (default): extra indices lead to seg fault (user must make sure
indices are in-bound)
+ - nan: produce NaNs for out-of-bounds indices
+ - wrap: wrap around the indices
+ - clip: clip to the range
Returns
-------
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index e3626b8e3a..bbee9a502d 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -44,9 +44,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
/* relax.take */
TVM_REGISTER_NODE_TYPE(TakeAttrs);
-Expr take(Expr x, Expr indices, Optional<int64_t> axis) {
+Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode) {
ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
attrs->axis = std::move(axis);
+ attrs->mode = std::move(mode);
static const Op& op = Op::Get("relax.take");
return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {});
@@ -100,8 +101,10 @@ StructInfo InferStructInfoTake(const Call& call, const
BlockBuilder& ctx) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim,
data_sinfo->vdevice);
}
- int axis =
- attrs->axis.has_value() ? NormalizeAxis(call, ctx, data_sinfo->ndim,
attrs->axis.value()) : 0;
+ int axis = 0;
+ if (attrs->axis.has_value()) {
+ axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value());
+ }
const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
if (data_shape == nullptr || indices_shape == nullptr) {
diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h
index 63a12e28f6..a45fb93792 100644
--- a/src/relax/op/tensor/index.h
+++ b/src/relax/op/tensor/index.h
@@ -38,9 +38,10 @@ namespace relax {
* It is required to be a one-dimensional tensor which has integer dtype.
* \param axis The axis over which to select values.
* If it is `std::nullopt`, the input tensor is required to be one-dimensional.
+ * \param mode The mode for handling out-of-bounds indices: "fast" (default; no
+ * bounds handling, indices must be in range), "nan", "wrap" or "clip".
* \return The taken result.
*/
-Expr take(Expr x, Expr indices, Optional<int64_t> axis);
+Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode = "fast");
/*!
* \brief Strided slice of a tensor.
diff --git a/tests/python/relax/test_op_take.py
b/tests/python/relax/test_op_take.py
index 15ca5f4c89..704895d0e4 100644
--- a/tests/python/relax/test_op_take.py
+++ b/tests/python/relax/test_op_take.py
@@ -154,5 +154,85 @@ def test_take_dynamic_prim_value_as_index(target, dev,
axis):
tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+@tvm.testing.parametrize_targets("llvm")
+def test_take_nan_mode_OOB_indices(target, dev, axis):
+    """Test R.take with mode="nan" and out-of-bounds indices.
+
+    This test checks that out-of-bounds indices produce NaN values in the
+    output tensor.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([3, 3], "float16")):
+            output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="nan")
+            return output
+
+    built = tvm.compile(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype="float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    if axis == 0:
+        # Index 3 is past the last row, so the whole 4th output row is NaN.
+        np_expected = np.array(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [np.nan, np.nan, np.nan]],
+            dtype="float16",
+        )
+    elif axis == 1:
+        # Index 3 is past the last column, so the whole 4th output column is NaN.
+        np_expected = np.array(
+            [[1.0, 2.0, 3.0, np.nan], [4.0, 5.0, 6.0, np.nan], [7.0, 8.0, 9.0, np.nan]],
+            dtype="float16",
+        )
+    else:
+        # Fail loudly instead of hitting an unbound `np_expected` below.
+        raise ValueError(f"Unexpected axis {axis} for a 2-D input")
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_wrap_mode_OOB_indices(target, dev, axis):
+    """Test R.take with mode="wrap" and out-of-bounds indices.
+
+    This test checks that out-of-bounds indices wrap around to the valid range.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([3, 3], "float16")):
+            output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="wrap")
+            return output
+
+    built = tvm.compile(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[3, 3]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    # numpy's "wrap" mode is the reference semantics (index 3 -> 0 on a size-3 axis).
+    np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="wrap")
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_take_clip_mode_OOB_indices(target, dev, axis):
+    """Test R.take with mode="clip" and out-of-bounds indices.
+
+    This test checks that out-of-bounds indices are clipped to the valid range.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([3, 3], "float16")):
+            output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="clip")
+            return output
+
+    built = tvm.compile(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+    np_input = np.random.random(size=[3, 3]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    # numpy's "clip" mode is the reference semantics (index 3 -> 2 on a size-3 axis).
+    np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="clip")
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
if __name__ == "__main__":
tvm.testing.main()