This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9eb8b3004b Add support for bucketize (#18040)
9eb8b3004b is described below

commit 9eb8b3004bf90d9082be680db5ce38121366c1d3
Author: kavin-mcw <[email protected]>
AuthorDate: Mon Jun 30 20:34:14 2025 +0530

    Add support for bucketize (#18040)
    
    * add support for bucketize
    
    * fix lint issue
    
    * Fix lint issue
    
    * Add GPU code for bucketize
    
    * Resolve merge conflict
    
    * Fix lint issue
---
 include/tvm/relax/attrs/search.h                   | 18 +++++
 python/tvm/relax/backend/dispatch_sort_scan.py     | 12 +++
 .../frontend/torch/base_fx_graph_translator.py     | 12 +++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  1 +
 python/tvm/relax/op/__init__.py                    |  2 +-
 python/tvm/relax/op/search.py                      | 25 ++++++
 python/tvm/relax/transform/legalize_ops/search.py  | 10 +++
 python/tvm/script/ir_builder/relax/ir.py           |  2 +
 python/tvm/topi/gpu/sort.py                        | 89 +++++++++++++++++++++-
 src/relax/op/tensor/search.cc                      | 53 ++++++++++++-
 src/relax/op/tensor/search.h                       | 10 +++
 .../relax/test_frontend_from_exported_program.py   | 25 ++++++
 tests/python/relax/test_frontend_from_fx.py        | 22 ++++++
 14 files changed, 279 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h
index 2abef5aee6..6fdbe59cea 100644
--- a/include/tvm/relax/attrs/search.h
+++ b/include/tvm/relax/attrs/search.h
@@ -49,6 +49,24 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter<ArgmaxArgminAttrs> {
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgmaxArgminAttrs, BaseAttrsNode);
 };  // struct ArgmaxArgminAttrs
 
+/*! \brief Attributes for bucketize operator */
+struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter<BucketizeAttrs> {
+  bool out_int32;
+  bool right;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<BucketizeAttrs>()
+        .def_ro("out_int32", &BucketizeAttrs::out_int32,
+                "Indicate the output datatype, int32 if True, int64 otherwise.")
+        .def_ro("right", &BucketizeAttrs::right,
+                "Determines the behavior for values in boundaries");
+  }
+
+  static constexpr const char* _type_key = "relax.attrs.BucketizeAttrs";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BucketizeAttrs, BaseAttrsNode);
+};  // struct BucketizeAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py
index f8a7dfe203..1dac0bf230 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -75,6 +75,21 @@ class SortScanDispatcher(BackendDispatcher):
         if not isinstance(call.op, Op):
             return super().visit_call_(call)
 
+        if call.op.name == "relax.bucketize":
+            input_tensor = call.args[0]
+            boundaries = call.args[1]
+            right = call.attrs.right
+            tgt = self._get_target(call.struct_info)
+            te_func = topi.searchsorted
+            with tgt:
+                if self.is_gpu_target(tgt):
+                    te_func = topi.gpu.searchsorted
+            # The index dtype comes from the inferred output (int32 if out_int32 else int64),
+            # not from the dtype of the values being searched.
+            return self.builder_.call_te(
+                te_func, boundaries, input_tensor, right, call.struct_info.dtype
+            )
+
         if call.op.name == "relax.sort":
             tgt = self._get_target(call.struct_info)
             te_func = topi.sort
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0026ae62a6..1895119e79 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1376,6 +1376,18 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         y = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.where(condition, x, y))
 
+    def _bucketize(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        boundaries = args[1]
+
+        right = node.kwargs.get("right", False)
+        out_int32 = node.kwargs.get("out_int32", False)
+
+        return self.block_builder.emit(
+            relax.op.bucketize(input_tensor, boundaries, out_int32, right)
+        )
+
     ########## Manipulation ##########
 
     def _argsort(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1058647a4f..1a53a0cbdc 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -507,6 +507,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             "where.self": self._where,
+            "bucketize.Tensor": self._bucketize,
             # tensor manipulation
             "argsort.default": self._argsort,
             "broadcast_to.default": self._broadcast_to,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 7dce09b0d2..754129ffde 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -917,6 +917,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
             "where": self._where,
+            "bucketize": self._bucketize,
             # tensor manipulation
             "argsort": self._argsort,
             "broadcast_to": self._broadcast_to,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 9388831fce..fd3672368b 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -115,7 +115,7 @@ from .manipulate import (
 from .mask import masked_fill
 from .qdq import dequantize, quantize
 from .sampling import multinomial_from_uniform
-from .search import argmax, argmin, where
+from .search import argmax, argmin, bucketize, where
 from .set import nonzero, unique
 from .sorting import argsort, sort, topk
 from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance
diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py
index b097d78234..016b22b9b9 100644
--- a/python/tvm/relax/op/search.py
+++ b/python/tvm/relax/op/search.py
@@ -102,3 +102,31 @@ def argmin(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr:
         The computed result.
     """
     return _ffi_api.argmin(x, axis, keepdims)  # type: ignore
+
+
+def bucketize(
+    input_tensor: Expr, boundaries: Expr, out_int32: bool = False, right: bool = False
+) -> Expr:
+    """Returns the indices of the buckets to which each value in the input belongs.
+
+    Parameters
+    ----------
+    input_tensor : relax.Expr
+        N-D tensor containing the search values.
+
+    boundaries : relax.Expr
+        1-D tensor, must contain a strictly increasing sequence, or the return value is undefined.
+
+    out_int32 : bool
+        Indicate the output data type. int32 if True, int64 otherwise. Default=False
+
+    right : bool
+        Controls which index is returned when a value equals a boundary: if False, the first
+        suitable position; if True, the last one. Default=False. Similar to torch.bucketize.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result with same shape as input_tensor.
+    """
+    return _ffi_api.bucketize(input_tensor, boundaries, out_int32, right)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py
index 19ff00774c..89fddb4b95 100644
--- a/python/tvm/relax/transform/legalize_ops/search.py
+++ b/python/tvm/relax/transform/legalize_ops/search.py
@@ -39,3 +39,15 @@ def _argmax_argmin(te_func: TEFunc) -> LegalizeFunc:
 
 register_legalize("relax.argmax", _argmax_argmin(topi.argmax))
 register_legalize("relax.argmin", _argmax_argmin(topi.argmin))
+
+
+@register_legalize("relax.bucketize")
+def _bucketize(bb, call):
+    input_tensor = call.args[0]
+    boundaries = call.args[1]
+    right = call.attrs.right
+    # The index dtype comes from the inferred output (int32 if out_int32 else int64),
+    # not from the dtype of the values being searched.
+    return bb.call_te(
+        topi.searchsorted, boundaries, input_tensor, right, call.struct_info.dtype
+    )
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 1e48e9ea1a..43590dfa25 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -58,6 +58,7 @@ from tvm.relax.op import (
     bitwise_or,
     bitwise_xor,
     broadcast_to,
+    bucketize,
     builtin,
     call_builtin_with_ctx,
     call_dps_packed,
@@ -731,6 +732,7 @@ __all__ = [
     "bitwise_or",
     "bitwise_xor",
     "broadcast_to",
+    "bucketize",
     "builtin",
     "call_inplace_packed",
     "call_packed",
diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py
index 71854e4399..eb48da0a02 100644
--- a/python/tvm/topi/gpu/sort.py
+++ b/python/tvm/topi/gpu/sort.py
@@ -20,8 +20,9 @@ import tvm
 from tvm import te
 
 from ..transform import strided_slice, transpose
-from ..utils import ceil_div, swap
+from ..utils import ceil_div, prod, swap
 from ..math import cast, ceil_log2
+from ..searchsorted import binary_search
 
 
 def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
@@ -937,3 +938,89 @@ def topk_thrust(
         out = out[1]
 
     return out
+
+
+def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
+    """Find indices where elements should be inserted to maintain order.
+       If `sorted_sequence` is N-dimensional, the innermost dimension of
+       `values` are searched in the corresponding dimension of `sorted_sequence`.
+
+       This implementation is optimized for GPU execution.
+
+    Parameters
+    ----------
+    sorted_sequence : te.Tensor
+        N-D or 1-D Tensor, containing monotonically increasing sequence
+        on the innermost dimension.
+
+    values : te.Tensor
+        N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
+        the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
+        and `values` must be the same, and outer N-1 axes must have the same size.
+
+    right : bool, optional
+        Controls which index is returned if a value lands exactly on one of sorted values. If
+        False (side='left'), the index of the first suitable location found is given. If true
+        (side='right'), return the last such index.
+
+    out_dtype : string, optional
+        The data type of the output indices.
+
+    Returns
+    -------
+    indices : te.Tensor
+        Tensor with same shape as values, representing the indices of
+        elements of `values` if they are inserted in `sorted_sequence`.
+    """
+    if len(sorted_sequence.shape) > 1:
+        for i in range(len(values.shape) - 1):
+            assert (
+                values.shape[i] == sorted_sequence.shape[i]
+            ), "Outer dimensions of sorted_sequence and values must match for N-D searchsorted"
+
+    def ir(sorted_sequence_buf, values_buf, indices_buf):
+        ib = tvm.tir.ir_builder.create()
+        sorted_sequence_shape = sorted_sequence_buf.shape
+        values_shape = values_buf.shape
+        num_search = prod(values_shape)
+        search_range = sorted_sequence_shape[-1]
+
+        sorted_sequence_ptr = ib.buffer_ptr(sorted_sequence_buf)
+        values_ptr = ib.buffer_ptr(values_buf)
+        indices_ptr = ib.buffer_ptr(indices_buf)
+
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(num_search, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        with ib.if_scope(tid < num_search):
+            if len(sorted_sequence_shape) == 1:
+                sequence_offset = 0
+            else:
+                sequence_id = tid // values_shape[-1]
+                sequence_offset = sequence_id * search_range
+
+            indices_ptr[tid] = binary_search(
+                ib,
+                sequence_offset,
+                search_range,
+                sorted_sequence_ptr,
+                values_ptr[tid],
+                right,
+                out_dtype,
+            )
+
+        return ib.get()
+
+    return te.extern(
+        values.shape,
+        [sorted_sequence, values],
+        lambda ins, outs: ir(ins[0], ins[1], outs[0]),
+        name="searchsorted_gpu",
+        dtype=out_dtype,
+    )
diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc
index 4ebf288d5a..3e0236fc28 100644
--- a/src/relax/op/tensor/search.cc
+++ b/src/relax/op/tensor/search.cc
@@ -30,7 +30,55 @@
 namespace tvm {
 namespace relax {
 
-TVM_FFI_STATIC_INIT_BLOCK({ ArgmaxArgminAttrs::RegisterReflection(); });
+TVM_FFI_STATIC_INIT_BLOCK({
+  ArgmaxArgminAttrs::RegisterReflection();
+  BucketizeAttrs::RegisterReflection();
+});
+
+/* relax.bucketize */
+TVM_REGISTER_NODE_TYPE(BucketizeAttrs);
+
+Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) {
+  auto attrs = make_object<BucketizeAttrs>();
+  attrs->out_int32 = out_int32;
+  attrs->right = right;
+  static const Op& op = Op::Get("relax.bucketize");
+  return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {});
+}
+
+TVM_FFI_REGISTER_GLOBAL("relax.op.bucketize").set_body_typed(bucketize);
+
+StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo input_tensor_info = input_sinfo[0];
+  TensorStructInfo boundaries_info = input_sinfo[1];
+
+  if (!boundaries_info->IsUnknownNdim() && boundaries_info->ndim != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Bucketize requires boundary to be 1-D array but got "
+                     << boundaries_info->ndim);
+  }
+
+  auto attrs = call->attrs.as<BucketizeAttrs>();
+  // The output holds bucket indices, so its dtype is decided by out_int32, not by the input.
+  DataType out_dtype = attrs->out_int32 ? DataType::Int(32) : DataType::Int(64);
+
+  const auto* data_shape = input_tensor_info->shape.as<ShapeExprNode>();
+  if (data_shape) {
+    return TensorStructInfo(ShapeExpr(data_shape->values), out_dtype, input_tensor_info->vdevice);
+  }
+  return TensorStructInfo(out_dtype, input_tensor_info->ndim, input_tensor_info->vdevice);
+}
+
+TVM_REGISTER_OP("relax.bucketize")
+    .set_num_inputs(2)
+    .add_argument("input_tensor", "Tensor",
+                  " N-D tensor or a Scalar containing the search value(s).")
+    .add_argument("boundaries", "Tensor",
+                  "1-D tensor, must contain a strictly increasing sequence, or the return value is "
+                  "undefined.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBucketize)
+    .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.where */
 Expr where(Expr condition, Expr x1, Expr x2) {
diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h
index eb40171790..333b5afe76 100644
--- a/src/relax/op/tensor/search.h
+++ b/src/relax/op/tensor/search.h
@@ -30,6 +30,16 @@
 
 namespace tvm {
 namespace relax {
+/*!
+ * \brief Returns the indices of the buckets to which each value in the input belongs.
+ * \param input_tensor N-D tensor containing the search values.
+ * \param boundaries 1-D tensor, must contain a strictly increasing sequence.
+ * \param out_int32 Indicate the output data type. int32 if True, int64 otherwise.
+ * \param right Determines the behavior for values in boundaries. Similar to torch.bucketize
+ *
+ * \return The computed result with the same shape as input.
+ */
+Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right);
 
 /*!
  * \brief Selecting elements from either the input tensors depending on the value of the
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f0bdddbee3..406a5d9a1c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5504,6 +5504,31 @@ def test_where():
     verify_model(Where(), (condition, x, y), {}, Expected)
 
 
+def test_bucketize():
+    class Bucketize(Module):
+        def forward(self, input_tensor, boundaries):
+            return torch.bucketize(input_tensor, boundaries)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((20,), dtype="int64"), boundaries: R.Tensor((10,), dtype="int64")
+        ) -> R.Tuple(R.Tensor((20,), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tensor((20,), dtype="int64") = R.bucketize(
+                    input, boundaries, out_int32=False, right=False
+                )
+                gv: R.Tuple(R.Tensor((20,), dtype="int64")) = (lv,)
+                R.output(gv)
+            return gv
+
+    input_tensor = torch.arange(0, 20)
+    boundaries = torch.arange(0, 20, 2)
+
+    verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)
+
+
 def test_argsort():
     class Argsort(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 00c61bd31f..47ca0819a9 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5874,6 +5874,28 @@ def test_where():
     )
 
 
+def test_bucketize():
+    class Bucketize(Module):
+        def forward(self, input_tensor, boundaries):
+            return torch.bucketize(input_tensor, boundaries)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((5, 3), dtype="float32"), boundaries: R.Tensor((10,), dtype="float32")
+        ) -> R.Tensor((5, 3), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="int64") = R.bucketize(
+                    input, boundaries, out_int32=False, right=False
+                )
+                gv: R.Tensor((5, 3), dtype="int64") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Bucketize(), [([5, 3], "float32"), ([10], "float32")], {}, Expected)
+
+
 def test_argsort():
     class Argsort(Module):
         def forward(self, x):

Reply via email to