This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 79368cea1a [Relax][ONNX][Transform] Add mode choice, new mode, and 
warning for take() (#18061)
79368cea1a is described below

commit 79368cea1afab6268b2bbbff96ea3d63a524130e
Author: Youngsik Yang <[email protected]>
AuthorDate: Tue Jul 8 20:58:35 2025 +0900

    [Relax][ONNX][Transform] Add mode choice, new mode, and warning for take() 
(#18061)
    
    [Relax][Transform] Add mode choice, NaN mode, and warning for take()
    
    - Add a `mode` parameter to Relax’s `take()`
    - Add `NaN` mode to `take()`
    - Add unit tests covering all `take()` modes
    - Add a warning log for `fast` mode
    - Unify default modes in lower layers to `fast` for consistency with Relax
---
 include/tvm/relax/attrs/index.h                  |  7 ++-
 include/tvm/topi/transform.h                     | 39 +++++++++++-
 python/tvm/relax/op/index.py                     | 10 ++-
 python/tvm/relax/transform/legalize_ops/index.py |  8 +--
 python/tvm/topi/transform.py                     | 11 ++--
 src/relax/op/tensor/index.cc                     |  9 ++-
 src/relax/op/tensor/index.h                      |  3 +-
 tests/python/relax/test_op_take.py               | 80 ++++++++++++++++++++++++
 8 files changed, 147 insertions(+), 20 deletions(-)

diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index 6f283f7044..cc914449db 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -32,11 +32,14 @@ namespace relax {
 /*! \brief Attributes used in take operator */
 struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
   Optional<int64_t> axis;
+  String mode;
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<TakeAttrs>().def_ro("axis", &TakeAttrs::axis,
-                                        "The axis over which to select 
values.");
+    refl::ObjectDef<TakeAttrs>()
+        .def_ro("axis", &TakeAttrs::axis, "The axis over which to select 
values.")
+        .def_ro("mode", &TakeAttrs::mode, "The mode for handling out-of-bounds 
indices.",
+                refl::DefaultValue("fast"));
   }
 
   static constexpr const char* _type_key = "relax.attrs.TakeAttrs";
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index ac32ec6a42..df637f6f58 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -1008,7 +1008,7 @@ inline Array<Tensor> split_n_sections(const Tensor& x, 
int num_sections, int axi
  * \return A Tensor whose op member is the take operation
  */
 inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims,
-                   std::string mode = "clip", std::string name = "T_take",
+                   std::string mode = "fast", std::string name = "T_take",
                    std::string tag = kInjective) {
   Array<PrimExpr> a_shape = a->shape;
   Array<PrimExpr> out_shape = indices->shape;
@@ -1032,6 +1032,16 @@ inline Tensor take(const Tensor& a, const Tensor& 
indices, int batch_dims,
         out_shape,
         [&](const Array<Var>& out_index) { return 
a(UnravelIndex(indices(out_index), a_shape)); },
         name, tag);
+  } else if (mode == "nan") {
+    return compute(
+        out_shape,
+        [&](const Array<Var>& out_index) {
+          auto idx = tvm::if_then_else(
+              indices(out_index) < 0 || indices(out_index) >= a_size,
+              tvm::FloatImm(a->dtype, 
std::numeric_limits<float>::quiet_NaN()), indices(out_index));
+          return a(UnravelIndex(idx, a_shape));
+        },
+        name, tag);
   } else {  // mode == "wrap"
     return compute(
         out_shape,
@@ -1094,7 +1104,7 @@ inline Tensor sequence_mask(const Tensor& data, const 
Tensor& valid_length, doub
  * \return A Tensor whose op member is the take operation
  */
 inline Tensor take(const Tensor& a, Variant<Tensor, PrimExpr> indices, int 
batch_dims, int axis,
-                   std::string mode = "clip", std::string name = "T_take",
+                   std::string mode = "fast", std::string name = "T_take",
                    std::string tag = kInjective) {
   if (axis < 0) {
     axis += static_cast<int>(a->shape.size());
@@ -1206,6 +1216,8 @@ inline Tensor take(const Tensor& a, Variant<Tensor, 
PrimExpr> indices, int batch
           name, tag);
     }
   } else if (mode == "fast") {
+    LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. 
"
+                    "Make sure input indices are in bounds";
     return compute(
         out_shape,
         [&](const Array<Var>& out_index) {
@@ -1224,6 +1236,29 @@ inline Tensor take(const Tensor& a, Variant<Tensor, 
PrimExpr> indices, int batch
           return a(real_indices);
         },
         name, tag);
+  } else if (mode == "nan") {
+    return compute(
+        out_shape,
+        [&](const Array<Var>& out_index) {
+          Array<PrimExpr> indices_position;
+          for (size_t j = axis; j < static_cast<size_t>(axis + indices_len); 
++j) {
+            indices_position.push_back(out_index[j]);
+          }
+          Array<PrimExpr> real_indices;
+          for (size_t j = 0; j < static_cast<size_t>(axis); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          PrimExpr idx = get_index(indices_position);
+          real_indices.push_back(idx);
+          for (size_t j = axis + indices_len; j < out_index.size(); ++j) {
+            real_indices.push_back(out_index[j]);
+          }
+          PrimExpr in_bounds = idx >= 0 && idx < axis_dim;
+          return tvm::if_then_else(
+              in_bounds, a(real_indices),
+              tvm::tir::make_const(a->dtype, 
std::numeric_limits<float>::quiet_NaN()));
+        },
+        name, tag);
   } else {  // mode == "wrap"
     return compute(
         out_shape,
diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py
index ec68bd585c..d4283259a8 100644
--- a/python/tvm/relax/op/index.py
+++ b/python/tvm/relax/op/index.py
@@ -26,7 +26,7 @@ from .. import args_converter
 PrimExprLike = Union[int, PrimExpr]
 
 
-def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
+def take(x: Expr, indices: Expr, axis: Optional[int] = None, mode: str = 
"fast") -> Expr:
     """Take elements from a tensor along an axis.
     Its semantic is mostly similar to `numpy.take`
     (https://numpy.org/doc/stable/reference/generated/numpy.take.html),
@@ -45,12 +45,18 @@ def take(x: Expr, indices: Expr, axis: Optional[int] = 
None) -> Expr:
         The axis over which to select values.
         If it is none, the input tensor is required to be one-dimensional.
 
+    mode : str
+        Specifies how out-of-bounds indices will behave.
+        - fast (default): out-of-bounds indices lead to a segmentation fault 
(user must make sure indices are in bounds)
+        - nan: produce NaNs for out-of-bounds indices
+        - wrap: wrap around the indices
+        - clip: clip to the range
     Returns
     -------
     ret : relax.Expr
         The taken result.
     """
-    return _ffi_api.take(x, indices, axis)  # type: ignore
+    return _ffi_api.take(x, indices, axis, mode)  # type: ignore
 
 
 @args_converter.auto
diff --git a/python/tvm/relax/transform/legalize_ops/index.py 
b/python/tvm/relax/transform/legalize_ops/index.py
index 8d0ac535f6..d99c1f4db6 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -26,11 +26,9 @@ from .common import register_legalize
 
 @register_legalize("relax.take")
 def _take(bb: BlockBuilder, call: Call) -> Expr:
-    # Currently Relax `take` operator doesn't provide the mode choices and
-    # requires input indices to be in range.
-    # We use fast mode, which leads to runtime error whenever some index is
-    # out of bound.
-    return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, 
mode="fast")
+    # Currently "fast" is the default mode, which leads to segmentation faults
+    # when there are out-of-bounds indices.
+    return bb.call_te(topi.take, call.args[0], call.args[1], call.attrs.axis, 
mode=call.attrs.mode)
 
 
 @register_legalize("relax.strided_slice")
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index bcb3ff95fa..98cec99a09 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -446,7 +446,7 @@ def split(ary, indices_or_sections, axis=0):
     return cpp.split(ary, indices_or_sections, axis)
 
 
-def take(a, indices, axis=None, batch_dims=0, mode="clip"):
+def take(a, indices, axis=None, batch_dims=0, mode="fast"):
     """Take elements from an array along an axis.
 
     Parameters
@@ -465,10 +465,11 @@ def take(a, indices, axis=None, batch_dims=0, 
mode="clip"):
         The number of batch dimensions. By default is 0.
 
     mode : str, optional
-        Specifies how out-of-bound indices will behave.
-        clip - clip to the range (default)
-        wrap - wrap around the indices
-        fast - no clip or wrap around (user must make sure indices are 
in-bound)
+        Specifies how out-of-bounds indices will behave.
+        - fast (default): out-of-bounds indices lead to a segmentation fault 
(user must make sure indices are in bounds)
+        - nan: produce NaNs for out-of-bounds indices
+        - wrap: wrap around the indices
+        - clip: clip to the range
 
     Returns
     -------
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index e3626b8e3a..bbee9a502d 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -44,9 +44,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
 /* relax.take */
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
 
-Expr take(Expr x, Expr indices, Optional<int64_t> axis) {
+Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode) {
   ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
   attrs->axis = std::move(axis);
+  attrs->mode = std::move(mode);
 
   static const Op& op = Op::Get("relax.take");
   return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {});
@@ -100,8 +101,10 @@ StructInfo InferStructInfoTake(const Call& call, const 
BlockBuilder& ctx) {
     return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, 
data_sinfo->vdevice);
   }
 
-  int axis =
-      attrs->axis.has_value() ? NormalizeAxis(call, ctx, data_sinfo->ndim, 
attrs->axis.value()) : 0;
+  int axis = 0;
+  if (attrs->axis.has_value()) {
+    axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value());
+  }
   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
   const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
   if (data_shape == nullptr || indices_shape == nullptr) {
diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h
index 63a12e28f6..a45fb93792 100644
--- a/src/relax/op/tensor/index.h
+++ b/src/relax/op/tensor/index.h
@@ -38,9 +38,10 @@ namespace relax {
  * It is required to be a one-dimensional tensor which has integer dtype.
  * \param axis The axis over which to select values.
  * If it is `std::nullopt`, the input tensor is required to be one-dimensional.
+ * \param mode The mode for handling out-of-bounds indices.
  * \return The taken result.
  */
-Expr take(Expr x, Expr indices, Optional<int64_t> axis);
+Expr take(Expr x, Expr indices, Optional<int64_t> axis, String mode = "fast");
 
 /*!
  * \brief Strided slice of a tensor.
diff --git a/tests/python/relax/test_op_take.py 
b/tests/python/relax/test_op_take.py
index 15ca5f4c89..704895d0e4 100644
--- a/tests/python/relax/test_op_take.py
+++ b/tests/python/relax/test_op_take.py
@@ -154,5 +154,85 @@ def test_take_dynamic_prim_value_as_index(target, dev, 
axis):
     tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
 
 
[email protected]_targets("llvm")
+def test_take_nan_mode_OOB_indices(target, dev, axis):
+    """Test R.take with mode="nan" and out-of-bounds indices.
+    This test checks that out-of-bounds indices produce NaN values in the 
output tensor.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([3, 3], "float16")):
+            output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="nan")
+            return output
+
+    built = tvm.compile(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], 
dtype="float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    if axis == 0:
+        np_expected = np.array(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [np.nan, 
np.nan, np.nan]],
+            dtype="float16",
+        )
+    elif axis == 1:
+        np_expected = np.array(
+            [[1.0, 2.0, 3.0, np.nan], [4.0, 5.0, 6.0, np.nan], [7.0, 8.0, 9.0, 
np.nan]],
+            dtype="float16",
+        )
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
[email protected]_targets("llvm")
+def test_take_wrap_mode_OOB_indices(target, dev, axis):
+    """Test R.take with mode="wrap" and out-of-bounds indices.
+    This test checks that out-of-bounds indices wrap around to the valid range.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([3, 3], "float16")):
+            output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="wrap")
+            return output
+
+    built = tvm.compile(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+
+    np_input = np.random.random(size=[3, 3]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="wrap")
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
[email protected]_targets("llvm")
+def test_take_clip_mode_OOB_indices(target, dev, axis):
+    """Test R.take with mode="clip" and out-of-bounds indices.
+    This test checks that out-of-bounds indices are clipped to the valid range.
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(A: R.Tensor([3, 3], "float16")):
+            output = R.take(A, R.const([0, 1, 2, 3]), axis=axis, mode="clip")
+            return output
+
+    built = tvm.compile(Module, target=target)
+    vm = tvm.relax.VirtualMachine(built, dev)
+    np_input = np.random.random(size=[3, 3]).astype("float16")
+    tvm_input = tvm.nd.array(np_input, dev)
+    tvm_output = vm["main"](tvm_input)
+    np_expected = np.take(np_input, [0, 1, 2, 3], axis=axis, mode="clip")
+
+    tvm.testing.assert_allclose(tvm_output.numpy(), np_expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to