This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 38eb79c63f [Relax][Vision] Add get_valid_counts and classic NMS (#18943)
38eb79c63f is described below

commit 38eb79c63f1514402e75efdf71ab1868532115c8
Author: HoYi <[email protected]>
AuthorDate: Sun Mar 29 01:24:27 2026 +0800

    [Relax][Vision] Add get_valid_counts and classic NMS (#18943)
    
    ## Summary
    
    Add `relax.vision.get_valid_counts` and classic
    `relax.vision.non_max_suppression` for object-detection post-processing
    pipelines.
    
    `get_valid_counts` performs score-based bounding box filtering and
    compacts valid boxes to the front of each batch. Classic
    `non_max_suppression` performs flexible IoU-based suppression on
    filtered boxes, complementing the existing
    `all_class_non_max_suppression` for custom post-processing workflows.
    
    This PR implements the Relax-level registration, legalization, TOPI
    compute, and test coverage for both operators.
    
    ## Changes
    
    **Relax op registration and legalization:**
    - C++ op functions, FFI registration, and struct info inference for both
    operators (`vision.h`, `vision.cc`)
    - Python wrappers with Relax docstrings (`vision.py`)
    - Legalization to `topi.vision.get_valid_counts` and
    `topi.vision.non_max_suppression`
    - Additional struct-info validation for `score_index`, `id_index`, and
    `coord_start` when `elem_length` is statically known
    
    **TOPI and testing:**
    - Full TOPI implementation for `get_valid_counts`
    - Reimplementation of classic `non_max_suppression` in TOPI
    - NumPy reference implementations in `tvm.topi.testing` for both
    operators
    - Op-level tests for struct info inference, legalization, invalid
    attribute ranges, and e2e numerical correctness
    - Stronger legalization tests that verify both the introduction of
    `relax.call_tir` and the removal of the original Relax vision op
    
    ## Limitations
    
    - Attribute range validation for `score_index`, `id_index`, and
    `coord_start` is only enforced when the input `elem_length` is
    statically known during struct-info inference.
    - Classic `non_max_suppression` follows the existing Relax / TOPI API
    shape and is intended for single-class or class-aware custom
    post-processing flows, distinct from `all_class_non_max_suppression`.
    
    ## Validation
    
    ```bash
    pytest tests/python/relax/test_op_vision.py -k "get_valid_counts" -v
    pytest tests/python/relax/test_op_vision.py -k "test_nms_" -v
    ```
    All related tests passed.
---
 include/tvm/relax/attrs/vision.h                   |  60 ++
 python/tvm/relax/op/__init__.py                    |   8 +-
 python/tvm/relax/op/op_attrs.py                    |  10 +
 python/tvm/relax/op/vision/nms.py                  | 114 ++-
 python/tvm/relax/transform/legalize_ops/vision.py  |  30 +
 python/tvm/topi/testing/__init__.py                |   2 +
 python/tvm/topi/testing/get_valid_counts_python.py |  68 ++
 python/tvm/topi/testing/nms_python.py              | 146 ++++
 python/tvm/topi/vision/nms.py                      | 503 +++++++++++-
 python/tvm/topi/vision/nms_util.py                 |   4 +-
 src/relax/op/vision/nms.cc                         | 244 +++++-
 src/relax/op/vision/nms.h                          |   9 +
 tests/python/relax/test_op_vision.py               | 862 ++++++++++++++++++++-
 .../relax/test_tvmscript_parser_op_vision.py       | 132 ++++
 14 files changed, 2166 insertions(+), 26 deletions(-)

diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 4e3351bb90..69ce458e7e 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,66 @@ struct ROIAlignAttrs : public 
AttrsNodeReflAdapter<ROIAlignAttrs> {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", 
ROIAlignAttrs, BaseAttrsNode);
 };  // struct ROIAlignAttrs
 
+/*! \brief Attributes used in GetValidCounts operator */
+struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
+  double score_threshold;
+  int id_index;
+  int score_index;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<GetValidCountsAttrs>()
+        .def_ro("score_threshold", &GetValidCountsAttrs::score_threshold,
+                "Lower limit of score for valid bounding boxes.")
+        .def_ro("id_index", &GetValidCountsAttrs::id_index,
+                "Index of the class categories, -1 to disable.")
+        .def_ro("score_index", &GetValidCountsAttrs::score_index,
+                "Index of the scores/confidence of boxes.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GetValidCountsAttrs", 
GetValidCountsAttrs,
+                                    BaseAttrsNode);
+};  // struct GetValidCountsAttrs
+
+/*! \brief Attributes used in NonMaximumSuppression operator */
+struct NonMaximumSuppressionAttrs
+    : public AttrsNodeReflAdapter<NonMaximumSuppressionAttrs> {
+  int max_output_size;
+  double iou_threshold;
+  bool force_suppress;
+  int top_k;
+  int coord_start;
+  int score_index;
+  int id_index;
+  bool return_indices;
+  bool invalid_to_bottom;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<NonMaximumSuppressionAttrs>()
+        .def_ro("max_output_size", 
&NonMaximumSuppressionAttrs::max_output_size,
+                "Max number of output valid boxes, -1 for no limit.")
+        .def_ro("iou_threshold", &NonMaximumSuppressionAttrs::iou_threshold,
+                "Non-maximum suppression IoU threshold.")
+        .def_ro("force_suppress", &NonMaximumSuppressionAttrs::force_suppress,
+                "Whether to suppress all detections regardless of class_id.")
+        .def_ro("top_k", &NonMaximumSuppressionAttrs::top_k,
+                "Keep maximum top k detections before nms, -1 for no limit.")
+        .def_ro("coord_start", &NonMaximumSuppressionAttrs::coord_start,
+                "Start index of the consecutive 4 coordinates.")
+        .def_ro("score_index", &NonMaximumSuppressionAttrs::score_index,
+                "Index of the scores/confidence of boxes.")
+        .def_ro("id_index", &NonMaximumSuppressionAttrs::id_index,
+                "Index of the class categories, -1 to disable.")
+        .def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices,
+                "Whether to return box indices in input data.")
+        .def_ro("invalid_to_bottom", 
&NonMaximumSuppressionAttrs::invalid_to_bottom,
+                "Whether to move all valid bounding boxes to the top.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
+                                    NonMaximumSuppressionAttrs, BaseAttrsNode);
+};  // struct NonMaximumSuppressionAttrs
+
+
 /*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box 
decode). */
 struct MultiboxTransformLocAttrs : public 
AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
   bool clip;
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index ee1a2c2420..0b8dc4e7de 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,13 @@ from .unary import (
     tanh,
     trunc,
 )
-from .vision import all_class_non_max_suppression, multibox_transform_loc, 
roi_align
+from .vision import (
+    all_class_non_max_suppression,
+    get_valid_counts,
+    multibox_transform_loc,
+    non_max_suppression,
+    roi_align,
+)
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index e8c91f04b4..d85c439d3a 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -246,6 +246,16 @@ class AllClassNonMaximumSuppressionAttrs(Attrs):
     """Attributes for vision.all_class_non_max_suppression"""
 
 
+@tvm_ffi.register_object("relax.attrs.GetValidCountsAttrs")
+class GetValidCountsAttrs(Attrs):
+    """Attributes for vision.get_valid_counts"""
+
+
+@tvm_ffi.register_object("relax.attrs.NonMaximumSuppressionAttrs")
+class NonMaximumSuppressionAttrs(Attrs):
+    """Attributes for vision.non_max_suppression"""
+
+
 @tvm_ffi.register_object("relax.attrs.ROIAlignAttrs")
 class ROIAlignAttrs(Attrs):
     """Attributes for vision.roi_align"""
diff --git a/python/tvm/relax/op/vision/nms.py 
b/python/tvm/relax/op/vision/nms.py
index 616c74ddf6..4eb3eb7f7a 100644
--- a/python/tvm/relax/op/vision/nms.py
+++ b/python/tvm/relax/op/vision/nms.py
@@ -14,9 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Non-maximum suppression operator"""
+"""Non-maximum suppression operators."""
 
-# from tvm import relax  # Unused import
 from . import _ffi_api
 
 
@@ -72,3 +71,114 @@ def all_class_non_max_suppression(
     return _ffi_api.all_class_non_max_suppression(
         boxes, scores, max_output_boxes_per_class, iou_threshold, 
score_threshold, output_format
     )
+
+
+def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
+    """Get valid count of bounding boxes given a score threshold.
+    Also moves valid boxes to the top of input data.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+    score_threshold : float, optional
+        Lower limit of score for valid bounding boxes.
+
+    id_index : int, optional
+        Index of the class categories. Set to ``-1`` to disable the class-id 
check.
+
+    score_index : int, optional
+        Index of the scores/confidence of boxes.
+
+    Returns
+    -------
+    out : relax.Expr
+        A tuple ``(valid_count, out_tensor, out_indices)`` where 
``valid_count``
+        has shape ``[batch_size]``, ``out_tensor`` has shape
+        ``[batch_size, num_anchors, elem_length]``, and ``out_indices`` has 
shape
+        ``[batch_size, num_anchors]``.
+    """
+    return _ffi_api.get_valid_counts(data, score_threshold, id_index, 
score_index)
+
+
+def non_max_suppression(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
+    """Non-maximum suppression operator for object detection.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+    valid_count : relax.Expr
+        1-D tensor for valid number of boxes.
+
+    indices : relax.Expr
+        2-D tensor with shape [batch_size, num_anchors].
+
+    max_output_size : int, optional
+        Max number of output valid boxes, -1 for no limit.
+
+    iou_threshold : float, optional
+        Non-maximum suppression IoU threshold.
+
+    force_suppress : bool, optional
+        Whether to suppress all detections regardless of class_id. When
+        ``id_index`` is ``-1``, all valid boxes are treated as belonging to the
+        same class, so this flag has the same effect as ``True``.
+
+    top_k : int, optional
+        Keep maximum top k detections before nms, -1 for no limit.
+
+    coord_start : int, optional
+        Start index of the consecutive 4 coordinates.
+
+    score_index : int, optional
+        Index of the scores/confidence of boxes.
+
+    id_index : int, optional
+        Index of the class categories. Set to ``-1`` to suppress boxes across
+        all classes.
+
+    return_indices : bool, optional
+        Whether to return box indices in input data.
+
+    invalid_to_bottom : bool, optional
+        Whether to move valid bounding boxes to the top of the returned tensor.
+        This option only affects the ``return_indices=False`` path.
+
+    Returns
+    -------
+    out : relax.Expr
+        If ``return_indices`` is ``True``, returns
+        ``(box_indices, valid_box_count)`` with shapes
+        ``[batch_size, num_anchors]`` and ``[batch_size, 1]``.
+        Otherwise returns the modified data tensor.
+    """
+    return _ffi_api.non_max_suppression(
+        data,
+        valid_count,
+        indices,
+        max_output_size,
+        iou_threshold,
+        force_suppress,
+        top_k,
+        coord_start,
+        score_index,
+        id_index,
+        return_indices,
+        invalid_to_bottom,
+    )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py 
b/python/tvm/relax/transform/legalize_ops/vision.py
index 28367a67a3..ea0458bfce 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -120,6 +120,36 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.vision.get_valid_counts")
+def _get_valid_counts(block_builder: BlockBuilder, call: Call) -> Expr:
+    return block_builder.call_te(
+        topi.vision.get_valid_counts,
+        call.args[0],
+        score_threshold=call.attrs.score_threshold,
+        id_index=call.attrs.id_index,
+        score_index=call.attrs.score_index,
+    )
+
+
+@register_legalize("relax.vision.non_max_suppression")
+def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
+    return block_builder.call_te(
+        topi.vision.non_max_suppression,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        max_output_size=call.attrs.max_output_size,
+        iou_threshold=call.attrs.iou_threshold,
+        force_suppress=call.attrs.force_suppress,
+        top_k=call.attrs.top_k,
+        coord_start=call.attrs.coord_start,
+        score_index=call.attrs.score_index,
+        id_index=call.attrs.id_index,
+        return_indices=call.attrs.return_indices,
+        invalid_to_bottom=call.attrs.invalid_to_bottom,
+    )
+
+
 @register_legalize("relax.vision.multibox_transform_loc")
 def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
     variances = tuple(float(x) for x in call.attrs.variances)
diff --git a/python/tvm/topi/testing/__init__.py 
b/python/tvm/topi/testing/__init__.py
index d9fd005921..143ccb8459 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -54,9 +54,11 @@ from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_python import gather_python
 from .gather_nd_python import gather_nd_python
+from .get_valid_counts_python import get_valid_counts_python
 from .strided_slice_python import strided_slice_python, strided_set_python
 from .batch_matmul import batch_matmul
 from .batch_norm import batch_norm
+from .nms_python import non_max_suppression_python
 from .slice_axis_python import slice_axis_python
 from .sequence_mask_python import sequence_mask
 from .poolnd_python import poolnd_python
diff --git a/python/tvm/topi/testing/get_valid_counts_python.py 
b/python/tvm/topi/testing/get_valid_counts_python.py
new file mode 100644
index 0000000000..2caab6babc
--- /dev/null
+++ b/python/tvm/topi/testing/get_valid_counts_python.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Numpy reference implementation for get_valid_counts."""
+import numpy as np
+
+
+def get_valid_counts_python(data, score_threshold=0, id_index=0, 
score_index=1):
+    """Numpy reference for get_valid_counts.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        3-D array with shape [batch_size, num_anchors, elem_length].
+
+    score_threshold : float
+        Lower limit of score for valid bounding boxes.
+
+    id_index : int
+        Index of the class categories, -1 to disable.
+
+    score_index : int
+        Index of the scores/confidence of boxes.
+
+    Returns
+    -------
+    valid_count : numpy.ndarray
+        1-D array, shape [batch_size].
+
+    out_tensor : numpy.ndarray
+        Rearranged data, shape [batch_size, num_anchors, elem_length].
+
+    out_indices : numpy.ndarray
+        Indices mapping, shape [batch_size, num_anchors].
+    """
+    batch_size, num_anchors, box_data_length = data.shape
+    valid_count = np.zeros(batch_size, dtype="int32")
+    out_tensor = np.full_like(data, -1.0)
+    out_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
+
+    for i in range(batch_size):
+        cnt = 0
+        for j in range(num_anchors):
+            score = data[i, j, score_index]
+            if id_index < 0:
+                is_valid = score > score_threshold
+            else:
+                is_valid = score > score_threshold and data[i, j, id_index] >= 0
+            if is_valid:
+                out_tensor[i, cnt, :] = data[i, j, :]
+                out_indices[i, cnt] = j
+                cnt += 1
+        valid_count[i] = cnt
+
+    return valid_count, out_tensor, out_indices
diff --git a/python/tvm/topi/testing/nms_python.py 
b/python/tvm/topi/testing/nms_python.py
new file mode 100644
index 0000000000..7c8c20f5b4
--- /dev/null
+++ b/python/tvm/topi/testing/nms_python.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Numpy reference implementation for classic non_max_suppression."""
+import numpy as np
+
+
+def _iou(box_a, box_b, coord_start):
+    """Compute IoU between two boxes."""
+    a = box_a[coord_start : coord_start + 4]
+    b = box_b[coord_start : coord_start + 4]
+
+    a_l, a_t, a_r, a_b = min(a[0], a[2]), min(a[1], a[3]), max(a[0], a[2]), 
max(a[1], a[3])
+    b_l, b_t, b_r, b_b = min(b[0], b[2]), min(b[1], b[3]), max(b[0], b[2]), 
max(b[1], b[3])
+
+    w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
+    h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
+    area = w * h
+    u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
+    return 0.0 if u <= 0 else area / u
+
+
+def non_max_suppression_python(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
+    """Numpy reference for classic non_max_suppression.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        3-D array, shape [batch_size, num_anchors, elem_length].
+
+    valid_count : numpy.ndarray
+        1-D array, shape [batch_size].
+
+    indices : numpy.ndarray
+        2-D array, shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    If return_indices is True: (box_indices, valid_box_count)
+    Otherwise: modified data tensor
+    """
+    batch_size, num_anchors, _ = data.shape
+    out_data = np.full_like(data, -1.0)
+    out_box_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
+    compacted = np.full((batch_size, num_anchors), -1, dtype="int32")
+    valid_box_count = np.zeros((batch_size, 1), dtype="int32")
+
+    for i in range(batch_size):
+        nkeep = int(valid_count[i])
+        if 0 < top_k < nkeep:
+            nkeep = top_k
+
+        # Sort by score descending
+        scores = data[i, :nkeep, score_index].copy()
+        sorted_idx = np.argsort(-scores)
+
+        # Copy sorted boxes
+        for j in range(nkeep):
+            src = sorted_idx[j]
+            out_data[i, j, :] = data[i, src, :]
+            out_box_indices[i, j] = src
+
+        # Greedy NMS
+        num_valid = 0
+        for j in range(nkeep):
+            if out_data[i, j, score_index] <= 0:
+                out_data[i, j, :] = -1.0
+                out_box_indices[i, j] = -1
+                continue
+            if 0 < max_output_size <= num_valid:
+                out_data[i, j, :] = -1.0
+                out_box_indices[i, j] = -1
+                continue
+
+            num_valid += 1
+
+            # Suppress overlapping boxes
+            for k in range(j + 1, nkeep):
+                if out_data[i, k, score_index] <= 0:
+                    continue
+
+                do_suppress = False
+                if force_suppress:
+                    do_suppress = True
+                elif id_index >= 0:
+                    do_suppress = out_data[i, j, id_index] == out_data[i, k, 
id_index]
+                else:
+                    do_suppress = True
+
+                if do_suppress:
+                    iou = _iou(out_data[i, j], out_data[i, k], coord_start)
+                    if iou >= iou_threshold:
+                        out_data[i, k, score_index] = -1.0
+                        out_box_indices[i, k] = -1
+
+        if return_indices:
+            # Compact valid indices to top and remap to original
+            cnt = 0
+            for j in range(num_anchors):
+                if out_box_indices[i, j] >= 0:
+                    orig_idx = out_box_indices[i, j]
+                    compacted[i, cnt] = int(indices[i, orig_idx])
+                    cnt += 1
+            valid_box_count[i, 0] = cnt
+
+    if return_indices:
+        return [compacted, valid_box_count]
+
+    if invalid_to_bottom:
+        # Rearrange valid boxes to top
+        result = np.full_like(data, -1.0)
+        for i in range(batch_size):
+            cnt = 0
+            for j in range(num_anchors):
+                if out_data[i, j, score_index] >= 0:
+                    result[i, cnt, :] = out_data[i, j, :]
+                    cnt += 1
+        return result
+
+    return out_data
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 9bdedc3535..b69e9c2aa1 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -36,37 +36,510 @@ from .nms_util import (
 )
 
 
-def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):  # 
pylint: disable=unused-argument
+def _get_valid_counts_ir(
+    data, score_threshold, id_index, score_index, valid_count, out_tensor, 
out_indices
+):
+    """IR for get_valid_counts. Filters boxes by score and compacts valid ones 
to the top."""
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    box_data_length = data.shape[2]
+
+    with IRBuilder() as ib:
+        data = T.buffer_proxy(data)
+        valid_count = T.buffer_proxy(valid_count)
+        out_tensor = T.buffer_proxy(out_tensor)
+        out_indices = T.buffer_proxy(out_indices)
+
+        with T.parallel(0, batch_size) as i:
+            valid_count[i] = T.int32(0)
+
+            with T.serial(0, num_anchors) as j:
+                score = data[i, j, score_index]
+                if id_index < 0:
+                    is_valid = score > score_threshold
+                else:
+                    is_valid = tvm.tirx.all(score > score_threshold, data[i, 
j, id_index] >= 0)
+
+                with T.If(is_valid):
+                    with T.Then():
+                        cur = valid_count[i]
+                        with T.serial(0, box_data_length) as k:
+                            out_tensor[i, cur, k] = data[i, j, k]
+                        out_indices[i, cur] = j
+                        valid_count[i] = cur + 1
+
+            # Fill remaining slots with -1
+            with T.serial(0, num_anchors) as j:
+                with T.If(j >= valid_count[i]):
+                    with T.Then():
+                        with T.serial(0, box_data_length) as k:
+                            out_tensor[i, j, k] = tvm.tirx.Cast(data.dtype, 
T.float32(-1.0))
+                        out_indices[i, j] = T.int32(-1)
+
+        return ib.get()
+
+
+def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     """Get valid count of bounding boxes given a score threshold.
     Also moves valid boxes to the top of input data.
+
     Parameters
     ----------
     data : tvm.te.Tensor
-        Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
-        or [batch_size, num_anchors, 5].
+        Input data. 3-D tensor with shape [batch_size, num_anchors, 
elem_length].
+
     score_threshold : optional, float
         Lower limit of score for valid bounding boxes.
+
     id_index : optional, int
-        index of the class categories, -1 to disable.
+        Index of the class categories, -1 to disable.
+
     score_index: optional, int
         Index of the scores/confidence of boxes.
+
     Returns
     -------
     valid_count : tvm.te.Tensor
-        1-D tensor for valid number of boxes.
+        1-D tensor for valid number of boxes, shape [batch_size].
+
     out_tensor : tvm.te.Tensor
-        Rearranged data tensor.
-    out_indices: tvm.te.Tensor or numpy NDArray
-        Related index in input data.
+        Rearranged data tensor, shape [batch_size, num_anchors, elem_length].
+
+    out_indices: tvm.te.Tensor
+        Related index in input data, shape [batch_size, num_anchors].
     """
-    if isinstance(score_threshold, float | int):
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    box_data_length = data.shape[2]
+
+    is_score_threshold_tensor = isinstance(score_threshold, te.Tensor)
+    if not is_score_threshold_tensor:
         score_threshold = tvm.tirx.const(score_threshold, dtype=data.dtype)
-    # id_index_const = tvm.tirx.const(id_index, "int32")  # Unused
-    # score_index_const = tvm.tirx.const(score_index, "int32")  # Unused
-    return (
-        te.compute((data.shape[0],), lambda i: data.shape[1], 
name="valid_count"),
-        data,
-        te.compute((data.shape[0], data.shape[1]), lambda i, j: j, 
name="out_indices"),
+
+    id_index_const = tvm.tirx.const(id_index, "int32")
+    score_index_const = tvm.tirx.const(score_index, "int32")
+
+    valid_count_buf = tvm.tirx.decl_buffer((batch_size,), "int32", 
"valid_count")
+    out_tensor_buf = tvm.tirx.decl_buffer(
+        (batch_size, num_anchors, box_data_length), data.dtype, "out_tensor"
+    )
+    out_indices_buf = tvm.tirx.decl_buffer(
+        (batch_size, num_anchors), "int32", "out_indices"
+    )
+
+    if is_score_threshold_tensor:
+        score_thresh_buf = tvm.tirx.decl_buffer(
+            score_threshold.shape, score_threshold.dtype, "score_threshold"
+        )
+        valid_count, out_tensor, out_indices = te.extern(
+            [(batch_size,), (batch_size, num_anchors, box_data_length), 
(batch_size, num_anchors)],
+            [data, score_threshold],
+            lambda ins, outs: _get_valid_counts_ir(
+                ins[0], ins[1], id_index_const, score_index_const,
+                outs[0], outs[1], outs[2],
+            ),
+            dtype=["int32", data.dtype, "int32"],
+            out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf],
+            in_buffers=[
+                tvm.tirx.decl_buffer(data.shape, data.dtype, "data"),
+                score_thresh_buf,
+            ],
+            name="get_valid_counts",
+            tag="get_valid_counts",
+        )
+    else:
+        # score_threshold is a TIR constant, not a tensor
+        def _ir_with_const_threshold(ins, outs):
+            return _get_valid_counts_ir(
+                ins[0], score_threshold, id_index_const, score_index_const,
+                outs[0], outs[1], outs[2],
+            )
+
+        valid_count, out_tensor, out_indices = te.extern(
+            [(batch_size,), (batch_size, num_anchors, box_data_length), 
(batch_size, num_anchors)],
+            [data],
+            _ir_with_const_threshold,
+            dtype=["int32", data.dtype, "int32"],
+            out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf],
+            in_buffers=[tvm.tirx.decl_buffer(data.shape, data.dtype, "data")],
+            name="get_valid_counts",
+            tag="get_valid_counts",
+        )
+
+    return valid_count, out_tensor, out_indices
+
+
+def _classic_nms_ir(
+    data,
+    sorted_index,
+    valid_count,
+    indices,
+    batch_size,
+    num_anchors,
+    box_data_length,
+    max_output_size,
+    iou_threshold,
+    force_suppress,
+    top_k,
+    coord_start,
+    score_index,
+    id_index,
+    return_indices,
+    out_data,
+    out_box_indices,
+    out_valid_box_count,
+):
+    """IR for classic single-class non-maximum suppression."""
+    with IRBuilder() as ib:
+        data = T.buffer_proxy(data)
+        sorted_index = T.buffer_proxy(sorted_index)
+        valid_count = T.buffer_proxy(valid_count)
+        indices = T.buffer_proxy(indices)
+        out_data = T.buffer_proxy(out_data)
+        out_box_indices = T.buffer_proxy(out_box_indices)
+        if out_valid_box_count is not None:
+            out_valid_box_count = T.buffer_proxy(out_valid_box_count)
+
+        with T.parallel(0, batch_size) as i:
+            # Step 1: Reorder data by sorted score
+            nkeep_buf = T.alloc_buffer((1,), "int32", scope="local")
+            nkeep_local = T.buffer_proxy(nkeep_buf)
+            nkeep_local[0] = valid_count[i]
+            with T.If(tvm.tirx.all(top_k > 0, top_k < nkeep_local[0])):
+                with T.Then():
+                    nkeep_local[0] = top_k
+
+            # Copy sorted boxes to output
+            with T.serial(0, num_anchors) as j:
+                with T.If(j < nkeep_local[0]):
+                    with T.Then():
+                        src_idx = sorted_index[i, j]
+                        with T.serial(0, box_data_length) as k:
+                            out_data[i, j, k] = data[i, src_idx, k]
+                        out_box_indices[i, j] = sorted_index[i, j]
+                    with T.Else():
+                        with T.serial(0, box_data_length) as k:
+                            out_data[i, j, k] = tvm.tirx.Cast(data.dtype, 
T.float32(-1.0))
+                        out_box_indices[i, j] = T.int32(-1)
+
+            # Step 2: Apply NMS - greedy suppression
+            num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local")
+            num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf)
+            num_valid_boxes[0] = T.int32(0)
+
+            with T.serial(0, nkeep_local[0]) as j:
+                # Check if box j is still valid (score > 0) and within 
max_output_size
+                with T.If(
+                    tvm.tirx.all(
+                        out_data[i, j, score_index] > 
tvm.tirx.Cast(data.dtype, T.float32(0.0)),
+                        tvm.tirx.Select(
+                            max_output_size > 0,
+                            num_valid_boxes[0] < max_output_size,
+                            tvm.tirx.const(True),
+                        ),
+                    )
+                ):
+                    with T.Then():
+                        num_valid_boxes[0] = num_valid_boxes[0] + 1
+
+                        # Suppress overlapping boxes
+                        with T.serial(0, nkeep_local[0]) as k:
+                            with T.If(
+                                tvm.tirx.all(
+                                    k > j,
+                                    out_data[i, k, score_index]
+                                    > tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
+                                )
+                            ):
+                                with T.Then():
+                                    # Check class ID match (or force_suppress)
+                                    do_suppress = tvm.tirx.const(False)
+                                    if force_suppress:
+                                        do_suppress = tvm.tirx.const(True)
+                                    elif id_index >= 0:
+                                        do_suppress = (
+                                            out_data[i, j, id_index] == 
out_data[i, k, id_index]
+                                        )
+                                    else:
+                                        do_suppress = tvm.tirx.const(True)
+
+                                    with T.If(do_suppress):
+                                        with T.Then():
+                                            # Calculate IoU
+                                            a_l = tvm.te.min(
+                                                out_data[i, j, coord_start],
+                                                out_data[i, j, coord_start + 
2],
+                                            )
+                                            a_t = tvm.te.min(
+                                                out_data[i, j, coord_start + 
1],
+                                                out_data[i, j, coord_start + 
3],
+                                            )
+                                            a_r = tvm.te.max(
+                                                out_data[i, j, coord_start],
+                                                out_data[i, j, coord_start + 
2],
+                                            )
+                                            a_b = tvm.te.max(
+                                                out_data[i, j, coord_start + 
1],
+                                                out_data[i, j, coord_start + 
3],
+                                            )
+
+                                            b_l = tvm.te.min(
+                                                out_data[i, k, coord_start],
+                                                out_data[i, k, coord_start + 
2],
+                                            )
+                                            b_t = tvm.te.min(
+                                                out_data[i, k, coord_start + 
1],
+                                                out_data[i, k, coord_start + 
3],
+                                            )
+                                            b_r = tvm.te.max(
+                                                out_data[i, k, coord_start],
+                                                out_data[i, k, coord_start + 
2],
+                                            )
+                                            b_b = tvm.te.max(
+                                                out_data[i, k, coord_start + 
1],
+                                                out_data[i, k, coord_start + 
3],
+                                            )
+
+                                            w = tvm.te.max(
+                                                tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
+                                                tvm.te.min(a_r, b_r) - 
tvm.te.max(a_l, b_l),
+                                            )
+                                            h = tvm.te.max(
+                                                tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
+                                                tvm.te.min(a_b, b_b) - 
tvm.te.max(a_t, b_t),
+                                            )
+                                            area = h * w
+                                            u = (
+                                                (a_r - a_l) * (a_b - a_t)
+                                                + (b_r - b_l) * (b_b - b_t)
+                                                - area
+                                            )
+                                            iou = tvm.tirx.Select(
+                                                u <= tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
+                                                tvm.tirx.Cast(data.dtype, 
T.float32(0.0)),
+                                                area / u,
+                                            )
+
+                                            with T.If(iou >= iou_threshold):
+                                                with T.Then():
+                                                    out_data[i, k, 
score_index] = tvm.tirx.Cast(
+                                                        data.dtype, 
T.float32(-1.0)
+                                                    )
+                                                    out_box_indices[i, k] = 
T.int32(-1)
+
+                    with T.Else():
+                        # Box suppressed or beyond max_output_size
+                        with T.serial(0, box_data_length) as k:
+                            out_data[i, j, k] = tvm.tirx.Cast(data.dtype, 
T.float32(-1.0))
+                        out_box_indices[i, j] = T.int32(-1)
+
+            # Step 3: If return_indices, remap to original indices
+            if return_indices:
+                if out_valid_box_count is not None:
+                    # Count valid boxes and remap indices
+                    valid_idx_buf = T.alloc_buffer((1,), "int32", 
scope="local")
+                    valid_idx = T.buffer_proxy(valid_idx_buf)
+                    valid_idx[0] = T.int32(0)
+
+                    with T.serial(0, num_anchors) as j:
+                        with T.If(out_box_indices[i, j] >= 0):
+                            with T.Then():
+                                orig_idx = out_box_indices[i, j]
+                                out_box_indices[i, valid_idx[0]] = indices[i, 
orig_idx]
+                                valid_idx[0] = valid_idx[0] + 1
+
+                    out_valid_box_count[i, 0] = valid_idx[0]
+
+                    # Fill remaining with -1
+                    with T.serial(0, num_anchors) as j:
+                        with T.If(j >= valid_idx[0]):
+                            with T.Then():
+                                out_box_indices[i, j] = T.int32(-1)
+
+        return ib.get()
+
+
+def non_max_suppression(
+    data,
+    valid_count,
+    indices,
+    max_output_size=-1,
+    iou_threshold=0.5,
+    force_suppress=False,
+    top_k=-1,
+    coord_start=2,
+    score_index=1,
+    id_index=0,
+    return_indices=True,
+    invalid_to_bottom=False,
+):
+    """Non-maximum suppression operator for object detection.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+    valid_count : tvm.te.Tensor
+        1-D tensor for valid number of boxes, shape [batch_size].
+
+    indices : tvm.te.Tensor
+        2-D tensor with shape [batch_size, num_anchors].
+
+    max_output_size : optional, int
+        Max number of output valid boxes for each instance.
+        Return all valid boxes if the value is less than or equal to 0.
+
+    iou_threshold : optional, float
+        Non-maximum suppression IoU threshold.
+
+    force_suppress : optional, boolean
+        Whether to suppress all detections regardless of class_id. When
+        ``id_index`` is ``-1``, all valid boxes are treated as belonging to the
+        same class, so this flag has the same effect as ``True``.
+
+    top_k : optional, int
+        Keep maximum top k detections before nms, -1 for no limit.
+
+    coord_start : optional, int
+        Start index of the consecutive 4 coordinates.
+
+    score_index : optional, int
+        Index of the scores/confidence of boxes.
+
+    id_index : optional, int
+        Index of the class categories, -1 to disable.
+
+    return_indices : optional, boolean
+        Whether to return box indices in input data.
+
+    invalid_to_bottom : optional, boolean
+        Whether to move all valid bounding boxes to the top (only when return_indices is False).
+
+    Returns
+    -------
+    out : tvm.te.Tensor or tuple of tvm.te.Tensor
+        If return_indices is True, returns a tuple of (box_indices, valid_box_count).
+        Otherwise returns the modified data tensor.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    box_data_length = data.shape[2]
+
+    if isinstance(max_output_size, int):
+        max_output_size = tvm.tirx.const(max_output_size, dtype="int32")
+    if isinstance(iou_threshold, (float, int)):
+        iou_threshold = tvm.tirx.const(iou_threshold, dtype=data.dtype)
+
+    # Sort by score
+    score_shape = (batch_size, num_anchors)
+    score_tensor = te.compute(
+        score_shape, lambda i, j: data[i, j, score_index], name="score_tensor"
+    )
+    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, 
is_ascend=False)
+
+    data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data")
+    sort_buf = tvm.tirx.decl_buffer(sort_tensor.shape, sort_tensor.dtype, 
"sorted_index")
+    valid_count_buf = tvm.tirx.decl_buffer(valid_count.shape, 
valid_count.dtype, "valid_count")
+    indices_buf = tvm.tirx.decl_buffer(indices.shape, indices.dtype, "indices")
+
+    out_data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "out_data")
+    out_box_indices_buf = tvm.tirx.decl_buffer(
+        (batch_size, num_anchors), "int32", "out_box_indices"
+    )
+
+    if return_indices:
+        out_valid_box_count_buf = tvm.tirx.decl_buffer(
+            (batch_size, 1), "int32", "out_valid_box_count"
+        )
+
+        out_data, out_box_indices, out_valid_box_count = te.extern(
+            [data.shape, (batch_size, num_anchors), (batch_size, 1)],
+            [data, sort_tensor, valid_count, indices],
+            lambda ins, outs: _classic_nms_ir(
+                ins[0], ins[1], ins[2], ins[3],
+                batch_size, num_anchors, box_data_length,
+                max_output_size, iou_threshold,
+                force_suppress, top_k,
+                coord_start, score_index, id_index,
+                return_indices,
+                outs[0], outs[1], outs[2],
+            ),
+            dtype=[data.dtype, "int32", "int32"],
+            out_buffers=[out_data_buf, out_box_indices_buf, 
out_valid_box_count_buf],
+            in_buffers=[data_buf, sort_buf, valid_count_buf, indices_buf],
+            name="non_max_suppression",
+            tag="non_max_suppression",
+        )
+        return [out_box_indices, out_valid_box_count]
+
+    out_data, out_box_indices = te.extern(
+        [data.shape, (batch_size, num_anchors)],
+        [data, sort_tensor, valid_count, indices],
+        lambda ins, outs: _classic_nms_ir(
+            ins[0], ins[1], ins[2], ins[3],
+            batch_size, num_anchors, box_data_length,
+            max_output_size, iou_threshold,
+            force_suppress, top_k,
+            coord_start, score_index, id_index,
+            return_indices,
+            outs[0], outs[1], None,
+        ),
+        dtype=[data.dtype, "int32"],
+        out_buffers=[out_data_buf, out_box_indices_buf],
+        in_buffers=[data_buf, sort_buf, valid_count_buf, indices_buf],
+        name="non_max_suppression",
+        tag="non_max_suppression",
+    )
+
+    if invalid_to_bottom:
+        # Rearrange to move valid boxes to top
+        return _rearrange_out(out_data, batch_size, num_anchors, 
box_data_length, score_index)
+
+    return out_data
+
+
+def _rearrange_out(data, batch_size, num_anchors, box_data_length, 
score_index):
+    """Move valid boxes (score >= 0) to the top of output."""
+    out_buf = tvm.tirx.decl_buffer(
+        (batch_size, num_anchors, box_data_length), data.dtype, "rearranged"
+    )
+
+    def _rearrange_ir(ins, outs):
+        with IRBuilder() as ib:
+            data = T.buffer_proxy(ins[0])
+            out = T.buffer_proxy(outs[0])
+
+            with T.parallel(0, batch_size) as i:
+                valid_idx_buf = T.alloc_buffer((1,), "int32", scope="local")
+                valid_idx = T.buffer_proxy(valid_idx_buf)
+                valid_idx[0] = T.int32(0)
+
+                with T.serial(0, num_anchors) as j:
+                    with T.If(
+                        data[i, j, score_index] >= tvm.tirx.Cast(data.dtype, 
T.float32(0.0))
+                    ):
+                        with T.Then():
+                            with T.serial(0, box_data_length) as k:
+                                out[i, valid_idx[0], k] = data[i, j, k]
+                            valid_idx[0] = valid_idx[0] + 1
+
+                with T.serial(0, num_anchors) as j:
+                    with T.If(j >= valid_idx[0]):
+                        with T.Then():
+                            with T.serial(0, box_data_length) as k:
+                                out[i, j, k] = tvm.tirx.Cast(data.dtype, 
T.float32(-1.0))
+
+            return ib.get()
+
+    return te.extern(
+        [(batch_size, num_anchors, box_data_length)],
+        [data],
+        _rearrange_ir,
+        dtype=[data.dtype],
+        out_buffers=[out_buf],
+        name="rearrange_out",
+        tag="rearrange_out",
     )
 
 
diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py
index ae17168970..a4b4c78363 100644
--- a/python/tvm/topi/vision/nms_util.py
+++ b/python/tvm/topi/vision/nms_util.py
@@ -303,8 +303,8 @@ def _all_class_nms_ir(
         if selected_scores is not None:
             selected_scores = T.buffer_proxy(selected_scores)
 
-        if isinstance(iou_threshold, float):
-            iou_threshold = tvm.tirx.FloatImm("float32", iou_threshold)
+        if isinstance(iou_threshold, (float, int)):
+            iou_threshold = tvm.tirx.FloatImm("float32", float(iou_threshold))
         elif isinstance(iou_threshold, te.Tensor):
             if len(iou_threshold.shape) == 0:
                 iou_threshold = iou_threshold()
diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc
index 294cd40c45..97508d7211 100644
--- a/src/relax/op/vision/nms.cc
+++ b/src/relax/op/vision/nms.cc
@@ -18,6 +18,7 @@
  */
 #include "nms.h"
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/ffi/string.h>
 #include <tvm/ir/attrs.h>
@@ -33,7 +34,11 @@
 namespace tvm {
 namespace relax {
 
-TVM_FFI_STATIC_INIT_BLOCK() { 
AllClassNonMaximumSuppressionAttrs::RegisterReflection(); }
+TVM_FFI_STATIC_INIT_BLOCK() {
+  AllClassNonMaximumSuppressionAttrs::RegisterReflection();
+  GetValidCountsAttrs::RegisterReflection();
+  NonMaximumSuppressionAttrs::RegisterReflection();
+}
 
 /* relax.vision.all_class_non_max_suppression */
 
@@ -110,5 +115,242 @@ TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllClassNMS)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.vision.get_valid_counts */
+
+Expr get_valid_counts(Expr data, double score_threshold, int id_index, int 
score_index) {
+  auto attrs = tvm::ffi::make_object<GetValidCountsAttrs>();
+  attrs->score_threshold = score_threshold;
+  attrs->id_index = id_index;
+  attrs->score_index = score_index;
+
+  static const Op& op = Op::Get("relax.vision.get_valid_counts");
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.get_valid_counts", get_valid_counts);
+}
+
+StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& 
ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "get_valid_counts expects 1 argument, got " << 
call->args.size());
+  }
+
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "get_valid_counts expects input data to be a Tensor.");
+  }
+  if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "get_valid_counts expects 3-D input, got ndim " << 
data_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<GetValidCountsAttrs>();
+  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid get_valid_counts attrs";
+  auto vdev = data_sinfo->vdevice;
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    tvm::ffi::Array<StructInfo> fields = {
+        TensorStructInfo(DataType::Int(32), /*ndim=*/1, vdev),
+        TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev),
+        TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)};
+    return TupleStructInfo(fields);
+  }
+
+  auto batch = data_shape->values[0];
+  auto num_anchors = data_shape->values[1];
+  auto elem_length = data_shape->values[2];
+  const auto* elem_length_imm = elem_length.as<IntImmNode>();
+  if (elem_length_imm != nullptr) {
+    if (attrs->score_index < 0 || attrs->score_index >= 
elem_length_imm->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "get_valid_counts expects score_index to be in range 
[0, "
+                       << elem_length_imm->value << "), but got " << 
attrs->score_index);
+    }
+    if (attrs->id_index < -1 || attrs->id_index >= elem_length_imm->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "get_valid_counts expects id_index to be in range 
[-1, "
+                       << elem_length_imm->value << "), but got " << 
attrs->id_index);
+    }
+  }
+
+  tvm::ffi::Array<StructInfo> fields = {
+      TensorStructInfo(ShapeExpr({batch}), DataType::Int(32), vdev),
+      TensorStructInfo(ShapeExpr({batch, num_anchors, elem_length}), 
data_sinfo->dtype, vdev),
+      TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), 
vdev)};
+  return TupleStructInfo(fields);
+}
+
+TVM_REGISTER_OP("relax.vision.get_valid_counts")
+    .set_attrs_type<GetValidCountsAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor",
+                  "Input data, 3-D tensor [batch_size, num_anchors, 
elem_length].")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoGetValidCounts)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.vision.non_max_suppression */
+
+Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int 
max_output_size,
+                         double iou_threshold, bool force_suppress, int top_k, 
int coord_start,
+                         int score_index, int id_index, bool return_indices,
+                         bool invalid_to_bottom) {
+  auto attrs = tvm::ffi::make_object<NonMaximumSuppressionAttrs>();
+  attrs->max_output_size = max_output_size;
+  attrs->iou_threshold = iou_threshold;
+  attrs->force_suppress = force_suppress;
+  attrs->top_k = top_k;
+  attrs->coord_start = coord_start;
+  attrs->score_index = score_index;
+  attrs->id_index = id_index;
+  attrs->return_indices = return_indices;
+  attrs->invalid_to_bottom = invalid_to_bottom;
+
+  static const Op& op = Op::Get("relax.vision.non_max_suppression");
+  return Call(op, {std::move(data), std::move(valid_count), 
std::move(indices)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.non_max_suppression", 
non_max_suppression);
+}
+
+StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects 3 arguments, got " << 
call->args.size());
+  }
+
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* valid_count_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const auto* indices_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects input data to be a 
Tensor.");
+  }
+  if (valid_count_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects valid_count to be a 
Tensor.");
+  }
+  if (indices_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects indices to be a Tensor.");
+  }
+  if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects 3-D input, got ndim " << 
data_sinfo->ndim);
+  }
+  if (valid_count_sinfo->ndim != -1 && valid_count_sinfo->ndim != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects valid_count to be 1-D, 
got ndim "
+                     << valid_count_sinfo->ndim);
+  }
+  if (indices_sinfo->ndim != -1 && indices_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects indices to be 2-D, got 
ndim "
+                     << indices_sinfo->ndim);
+  }
+  if (!valid_count_sinfo->IsUnknownDtype() && valid_count_sinfo->dtype != 
DataType::Int(32)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects valid_count to have dtype 
int32, got "
+                     << valid_count_sinfo->dtype);
+  }
+  if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype != 
DataType::Int(32)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "non_max_suppression expects indices to have dtype 
int32, got "
+                     << indices_sinfo->dtype);
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* valid_count_shape = valid_count_sinfo->shape.as<ShapeExprNode>();
+  const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape != nullptr) {
+    arith::Analyzer* analyzer = ctx->GetAnalyzer();
+    PrimExpr batch = data_shape->values[0];
+    PrimExpr num_anchors = data_shape->values[1];
+    if (valid_count_shape != nullptr &&
+        !analyzer->CanProveEqual(valid_count_shape->values[0], batch)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "non_max_suppression expects valid_count to have 
shape [batch_size]. "
+                          "However, the given data tensor has batch size `"
+                       << batch << "` and the given valid_count tensor has 
shape "
+                       << valid_count_sinfo->shape);
+    }
+    if (indices_shape != nullptr) {
+      if (!analyzer->CanProveEqual(indices_shape->values[0], batch) ||
+          !analyzer->CanProveEqual(indices_shape->values[1], num_anchors)) {
+        ctx->ReportFatal(
+            Diagnostic::Error(call)
+            << "non_max_suppression expects indices to have shape [batch_size, 
num_anchors]. "
+               "However, the given data tensor has shape "
+            << data_sinfo->shape << " and the given indices tensor has shape "
+            << indices_sinfo->shape);
+      }
+    }
+  }
+
+  const auto* attrs = call->attrs.as<NonMaximumSuppressionAttrs>();
+  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid non_max_suppression attrs";
+  auto vdev = data_sinfo->vdevice;
+  if (data_shape != nullptr) {
+    const auto* elem_length_imm = data_shape->values[2].as<IntImmNode>();
+    if (elem_length_imm != nullptr) {
+      int64_t elem_length = elem_length_imm->value;
+      if (attrs->score_index < 0 || attrs->score_index >= elem_length) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "non_max_suppression expects score_index to be in 
range [0, "
+                         << elem_length << "), but got " << 
attrs->score_index);
+      }
+      if (attrs->coord_start < 0 || attrs->coord_start + 3 >= elem_length) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "non_max_suppression expects coord_start to 
reference four "
+                            "consecutive box coordinates within elem_length "
+                         << elem_length << ", but got " << attrs->coord_start);
+      }
+      if (attrs->id_index < -1 || attrs->id_index >= elem_length) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "non_max_suppression expects id_index to be in 
range [-1, "
+                         << elem_length << "), but got " << attrs->id_index);
+      }
+    }
+  }
+
+  if (attrs->return_indices) {
+    // Returns (box_indices[batch, num_anchors], valid_box_count[batch, 1])
+    if (data_shape == nullptr) {
+      tvm::ffi::Array<StructInfo> fields = {
+          TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev),
+          TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)};
+      return TupleStructInfo(fields);
+    }
+    auto batch = data_shape->values[0];
+    auto num_anchors = data_shape->values[1];
+    tvm::ffi::Array<StructInfo> fields = {
+        TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), 
vdev),
+        TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}), 
DataType::Int(32),
+                         vdev)};
+    return TupleStructInfo(fields);
+  }
+
+  // Returns modified data tensor with the same shape as input.
+  if (const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>()) {
+    return TensorStructInfo(ffi::GetRef<ShapeExpr>(data_shape), 
data_sinfo->dtype, vdev);
+  }
+  return TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev);
+}
+
+TVM_REGISTER_OP("relax.vision.non_max_suppression")
+    .set_attrs_type<NonMaximumSuppressionAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor",
+                  "Input data, 3-D tensor [batch_size, num_anchors, 
elem_length].")
+    .add_argument("valid_count", "Tensor", "1-D tensor for valid number of 
boxes.")
+    .add_argument("indices", "Tensor", "2-D tensor with shape [batch_size, 
num_anchors].")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNMS)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h
index c86bf98c94..3fbd2609e2 100644
--- a/src/relax/op/vision/nms.h
+++ b/src/relax/op/vision/nms.h
@@ -38,6 +38,15 @@ Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxe
                                    Expr iou_threshold, Expr score_threshold,
                                    ffi::String output_format);
 
+/*! \brief Get valid count of bounding boxes given a score threshold. */
+Expr get_valid_counts(Expr data, double score_threshold, int id_index, int 
score_index);
+
+/*! \brief Non-maximum suppression for object detection. */
+Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int 
max_output_size,
+                         double iou_threshold, bool force_suppress, int top_k, 
int coord_start,
+                         int score_index, int id_index, bool return_indices,
+                         bool invalid_to_bottom);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py
index cded9f5f29..6d04a796ca 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -20,6 +20,7 @@ import pytest
 
 import tvm
 import tvm.testing
+import tvm.topi.testing
 from tvm import TVMError, relax, tirx
 from tvm.ir import Op
 from tvm.relax.transform import LegalizeOps
@@ -31,6 +32,23 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
     tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
 
 
+def _assert_relax_op_legalized(mod: tvm.IRModule, op_name: str) -> None:
+    seen_call_tir = False
+    seen_original_op = False
+
+    def _visit(expr):
+        nonlocal seen_call_tir, seen_original_op
+        if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
+            if expr.op.name == "relax.call_tir":
+                seen_call_tir = True
+            if expr.op.name == op_name:
+                seen_original_op = True
+
+    relax.analysis.post_order_visit(mod["main"].body, _visit)
+    assert seen_call_tir
+    assert not seen_original_op
+
+
 def test_roi_align_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
@@ -198,6 +216,840 @@ def test_roi_align_legalize_sample_ratio_zero():
     )
 
 
+def test_get_valid_counts_op_correctness():
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    assert relax.op.vision.get_valid_counts(data, 0.5).op == Op.get("relax.vision.get_valid_counts")
+
+
+def test_get_valid_counts_infer_struct_info():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    _check_inference(
+        bb,
+        relax.op.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2,), "int32"),
+                relax.TensorStructInfo((2, 10, 6), "float32"),
+                relax.TensorStructInfo((2, 10), "int32"),
+            ]
+        ),
+    )
+
+
+def test_get_valid_counts_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    n = tirx.Var("n", "int64")
+    m = tirx.Var("m", "int64")
+    k = tirx.Var("k", "int64")
+    data = relax.Var("data", R.Tensor((n, m, k), "float32"))
+    _check_inference(
+        bb,
+        relax.op.vision.get_valid_counts(data, score_threshold=0.0),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((n,), "int32"),
+                relax.TensorStructInfo((n, m, k), "float32"),
+                relax.TensorStructInfo((n, m), "int32"),
+            ]
+        ),
+    )
+
+
+def test_get_valid_counts_wrong_ndim():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((10, 6), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.get_valid_counts(data))
+
+
+def test_get_valid_counts_invalid_indices():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.get_valid_counts(data, score_index=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.get_valid_counts(data, id_index=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.get_valid_counts(data, id_index=-2))
+
+
+def test_nms_op_correctness():
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    assert relax.op.vision.non_max_suppression(
+        data, valid_count, indices
+    ).op == Op.get("relax.vision.non_max_suppression")
+
+
+def test_nms_infer_struct_info_return_indices():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    _check_inference(
+        bb,
+        relax.op.vision.non_max_suppression(
+            data, valid_count, indices, return_indices=True
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 10), "int32"),
+                relax.TensorStructInfo((2, 1), "int32"),
+            ]
+        ),
+    )
+
+
+def test_nms_infer_struct_info_return_data():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    _check_inference(
+        bb,
+        relax.op.vision.non_max_suppression(
+            data, valid_count, indices, return_indices=False
+        ),
+        relax.TensorStructInfo((2, 10, 6), "float32"),
+    )
+
+
+def test_nms_infer_struct_info_return_data_shape_var():
+    bb = relax.BlockBuilder()
+    batch_size = tirx.Var("batch_size", "int64")
+    num_anchors = tirx.Var("num_anchors", "int64")
+    elem_length = tirx.Var("elem_length", "int64")
+    data = relax.Var("data", R.Tensor((batch_size, num_anchors, elem_length), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((batch_size,), "int32"))
+    indices = relax.Var("indices", R.Tensor((batch_size, num_anchors), "int32"))
+    _check_inference(
+        bb,
+        relax.op.vision.non_max_suppression(
+            data, valid_count, indices, return_indices=False
+        ),
+        relax.TensorStructInfo((batch_size, num_anchors, elem_length), "float32"),
+    )
+
+
+def test_nms_wrong_ndim():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices))
+
+
+def test_nms_wrong_valid_count_ndim():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2, 1), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices))
+
+
+def test_nms_wrong_indices_ndim():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((20,), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices))
+
+
+def test_nms_wrong_aux_input_dtype():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count_i64 = relax.Var("valid_count_i64", R.Tensor((2,), "int64"))
+    valid_count_i32 = relax.Var("valid_count_i32", R.Tensor((2,), "int32"))
+    indices_i64 = relax.Var("indices_i64", R.Tensor((2, 10), "int64"))
+    indices_i32 = relax.Var("indices_i32", R.Tensor((2, 10), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count_i64, indices_i32))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count_i32, indices_i64))
+
+
+def test_nms_wrong_aux_input_shape():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count_bad_batch = relax.Var("valid_count_bad_batch", R.Tensor((3,), "int32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices_bad_batch = relax.Var("indices_bad_batch", R.Tensor((3, 10), "int32"))
+    indices_bad_anchors = relax.Var("indices_bad_anchors", R.Tensor((2, 9), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(
+            relax.op.vision.non_max_suppression(
+                data, valid_count_bad_batch, indices_bad_anchors
+            )
+        )
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices_bad_batch))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices_bad_anchors))
+
+
+def test_nms_invalid_indices():
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, score_index=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, id_index=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, id_index=-2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices, coord_start=3))
+
+
+def test_get_valid_counts_legalize():
+    @tvm.script.ir_module
+    class GVC:
+        @R.function
+        def main(
+            data: R.Tensor((1, 5, 6), "float32"),
+        ) -> R.Tuple(
+            R.Tensor((1,), "int32"),
+            R.Tensor((1, 5, 6), "float32"),
+            R.Tensor((1, 5), "int32"),
+        ):
+            gv = R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1)
+            return gv
+
+    mod = LegalizeOps()(GVC)
+    _assert_relax_op_legalized(mod, "relax.vision.get_valid_counts")
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((1,), "int32"),
+                relax.TensorStructInfo((1, 5, 6), "float32"),
+                relax.TensorStructInfo((1, 5), "int32"),
+            ]
+        ),
+    )
+
+
+def test_nms_legalize():
+    @tvm.script.ir_module
+    class NMS:
+        @R.function
+        def main(
+            data: R.Tensor((1, 5, 6), "float32"),
+            valid_count: R.Tensor((1,), "int32"),
+            indices: R.Tensor((1, 5), "int32"),
+        ) -> R.Tuple(R.Tensor((1, 5), "int32"), R.Tensor((1, 1), "int32")):
+            gv = R.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=-1,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=True,
+                invalid_to_bottom=False,
+            )
+            return gv
+
+    mod = LegalizeOps()(NMS)
+    _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression")
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((1, 5), "int32"),
+                relax.TensorStructInfo((1, 1), "int32"),
+            ]
+        ),
+    )
+
+
+def test_nms_legalize_return_data():
+    @tvm.script.ir_module
+    class NMS:
+        @R.function
+        def main(
+            data: R.Tensor((1, 5, 6), "float32"),
+            valid_count: R.Tensor((1,), "int32"),
+            indices: R.Tensor((1, 5), "int32"),
+        ) -> R.Tensor((1, 5, 6), "float32"):
+            gv = R.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=-1,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=False,
+                invalid_to_bottom=True,
+            )
+            return gv
+
+    mod = LegalizeOps()(NMS)
+    _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression")
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo((1, 5, 6), "float32"),
+    )
+
+
[email protected]_llvm
+def test_get_valid_counts_e2e():
+    """Run get_valid_counts through legalization and compare with the numpy reference."""
+
+    @tvm.script.ir_module
+    class GVCModule:
+        @R.function
+        def main(
+            data: R.Tensor((2, 5, 6), "float32"),
+        ) -> R.Tuple(
+            R.Tensor((2,), "int32"),
+            R.Tensor((2, 5, 6), "float32"),
+            R.Tensor((2, 5), "int32"),
+        ):
+            return R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1)
+
+    data_np = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [1.0, 0.30, 0.0, 0.0, 1.0, 1.0],
+                [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0],
+                [2.0, 0.75, 2.0, 2.0, 3.0, 3.0],
+                [1.0, 0.10, 4.0, 4.0, 5.0, 5.0],
+            ],
+            [
+                [0.0, 0.55, 0.0, 0.0, 1.0, 1.0],
+                [1.0, 0.80, 1.0, 1.0, 2.0, 2.0],
+                [2.0, 0.40, 2.0, 2.0, 3.0, 3.0],
+                [3.0, 0.60, 3.0, 3.0, 4.0, 4.0],
+                [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0],
+            ],
+        ],
+        dtype="float32",
+    )
+    ref_valid_count, ref_out_data, ref_out_indices = tvm.topi.testing.get_valid_counts_python(
+        data_np, score_threshold=0.5, id_index=0, score_index=1
+    )
+
+    mod = LegalizeOps()(GVCModule)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    result = vm["main"](tvm.runtime.tensor(data_np, tvm.cpu()))
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_valid_count)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_out_data)
+    tvm.testing.assert_allclose(result[2].numpy(), ref_out_indices)
+
+
+def _prepare_nms_inputs(raw_data: np.ndarray):
+    """Prepare classic NMS inputs with the numpy get_valid_counts reference."""
+
+    return tvm.topi.testing.get_valid_counts_python(
+        raw_data, score_threshold=0.5, id_index=0, score_index=1
+    )
+
+
+def _run_nms_e2e(
+    data_np: np.ndarray,
+    valid_count_np: np.ndarray,
+    indices_np: np.ndarray,
+    *,
+    max_output_size: int = -1,
+    iou_threshold: float = 0.5,
+    force_suppress: bool = False,
+    top_k: int = -1,
+    coord_start: int = 2,
+    score_index: int = 1,
+    id_index: int = 0,
+    return_indices: bool = True,
+    invalid_to_bottom: bool = False,
+):
+    """Run classic NMS through legalization and VM execution."""
+
+    data_shape = tuple(int(dim) for dim in data_np.shape)
+    valid_count_shape = tuple(int(dim) for dim in valid_count_np.shape)
+    indices_shape = tuple(int(dim) for dim in indices_np.shape)
+    data = relax.Var("data", relax.TensorStructInfo(data_shape, "float32"))
+    valid_count = relax.Var("valid_count", relax.TensorStructInfo(valid_count_shape, "int32"))
+    indices = relax.Var("indices", relax.TensorStructInfo(indices_shape, "int32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("main", (data, valid_count, indices)):
+        result = bb.emit(
+            relax.op.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=max_output_size,
+                iou_threshold=iou_threshold,
+                force_suppress=force_suppress,
+                top_k=top_k,
+                coord_start=coord_start,
+                score_index=score_index,
+                id_index=id_index,
+                return_indices=return_indices,
+                invalid_to_bottom=invalid_to_bottom,
+            )
+        )
+        bb.emit_func_output(result)
+
+    mod = LegalizeOps()(bb.get())
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    return vm["main"](
+        tvm.runtime.tensor(data_np, tvm.cpu()),
+        tvm.runtime.tensor(valid_count_np, tvm.cpu()),
+        tvm.runtime.tensor(indices_np, tvm.cpu()),
+    )
+
+
[email protected]_llvm
+def test_nms_e2e_return_indices():
+    """Run classic NMS through legalization and compare with the numpy reference."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+
+
[email protected]_llvm
+def test_nms_e2e_return_indices_with_invalid_to_bottom():
+    """Validate that invalid_to_bottom is a no-op when returning indices."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        return_indices=True,
+        invalid_to_bottom=True,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+
+
[email protected]_llvm
+def test_nms_e2e_top_k():
+    """Validate that classic NMS honors top_k before suppression."""
+
+    raw_data = np.array(
+        [
+            [
+                [-1.0, 0.99, 9.0, 9.0, 10.0, 10.0],
+                [0.0, 0.97, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.96, 2.0, 2.0, 3.0, 3.0],
+                [0.0, 0.95, 4.0, 4.0, 5.0, 5.0],
+                [1.0, 0.94, 6.0, 6.0, 7.0, 7.0],
+                [0.0, 0.20, 8.0, 8.0, 9.0, 9.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=2,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        top_k=2,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+    np.testing.assert_array_equal(ref_indices, np.array([[1, 2, -1, -1, -1, -1]], dtype="int32"))
+    np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32"))
+
+
[email protected]_llvm
+def test_nms_e2e_force_suppress():
+    """Validate that force_suppress ignores class ids when suppressing overlaps."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [1.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.80, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=True,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        force_suppress=True,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+    np.testing.assert_array_equal(ref_indices, np.array([[0, 2, -1, -1]], dtype="int32"))
+    np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32"))
+
+
[email protected]_llvm
+def test_nms_e2e_max_output_size():
+    """Validate that max_output_size truncates the kept boxes after score sorting."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.97, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.95, 2.0, 2.0, 3.0, 3.0],
+                [0.0, 0.93, 4.0, 4.0, 5.0, 5.0],
+                [0.0, 0.91, 6.0, 6.0, 7.0, 7.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=2,
+        iou_threshold=1,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=2,
+        iou_threshold=1,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+    np.testing.assert_array_equal(ref_indices, np.array([[0, 1, -1, -1]], dtype="int32"))
+    np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32"))
+
+
[email protected]_llvm
+def test_nms_e2e_multi_batch():
+    """Validate that classic NMS processes each batch independently."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.80, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
+            ],
+            [
+                [1.0, 0.96, 0.0, 0.0, 1.0, 1.0],
+                [2.0, 0.94, 0.04, 0.04, 1.04, 1.04],
+                [2.0, 0.88, 3.0, 3.0, 4.0, 4.0],
+                [2.0, 0.30, 6.0, 6.0, 7.0, 7.0],
+            ],
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+    np.testing.assert_array_equal(
+        ref_indices,
+        np.array([[0, 2, -1, -1], [0, 1, 2, -1]], dtype="int32"),
+    )
+    np.testing.assert_array_equal(ref_valid_box_count, np.array([[2], [3]], dtype="int32"))
+
+
[email protected]_llvm
+def test_nms_e2e_invalid_to_bottom():
+    """Validate that invalid_to_bottom compacts only boxes that remain valid after NMS."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_out_data = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=False,
+        invalid_to_bottom=True,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        return_indices=False,
+        invalid_to_bottom=True,
+    )
+    expected_out_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+            ]
+        ],
+        dtype="float32",
+    )
+
+    tvm.testing.assert_allclose(result.numpy(), ref_out_data)
+    tvm.testing.assert_allclose(result.numpy(), expected_out_data)
+
+
[email protected]_llvm
+def test_nms_e2e_return_data_without_compaction():
+    """Validate the return_indices=False path when invalid boxes stay in-place."""
+
+    raw_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_out_data = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=False,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        return_indices=False,
+        invalid_to_bottom=False,
+    )
+    expected_out_data = np.array(
+        [
+            [
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
+            ]
+        ],
+        dtype="float32",
+    )
+
+    tvm.testing.assert_allclose(result.numpy(), ref_out_data)
+    tvm.testing.assert_allclose(result.numpy(), expected_out_data)
+
+
[email protected]_llvm
+def test_nms_e2e_index_remap():
+    """Validate that returned indices remap from filtered order back to original order."""
+
+    raw_data = np.array(
+        [
+            [
+                [-1.0, 0.99, 9.0, 9.0, 10.0, 10.0],
+                [0.0, 0.60, 4.0, 4.0, 5.0, 5.0],
+                [0.0, 0.10, 8.0, 8.0, 9.0, 9.0],
+                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+                [1.0, 0.80, 2.0, 2.0, 3.0, 3.0],
+            ]
+        ],
+        dtype="float32",
+    )
+    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
+    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        max_output_size=-1,
+        iou_threshold=0.5,
+        force_suppress=False,
+        top_k=-1,
+        coord_start=2,
+        score_index=1,
+        id_index=0,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    result = _run_nms_e2e(
+        filtered_data_np,
+        valid_count_np,
+        filtered_indices_np,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+
+    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+    np.testing.assert_array_equal(ref_indices, np.array([[3, 5, 1, -1, -1, -1]], dtype="int32"))
+    np.testing.assert_array_equal(ref_valid_box_count, np.array([[3]], dtype="int32"))
+
+
 def test_all_class_non_max_suppression_infer_struct_info():
     bb = relax.BlockBuilder()
     batch_size, num_classes, num_boxes = 10, 8, 5
@@ -450,11 +1302,11 @@ def _multibox_ref_numpy(
     boxes = np.zeros((B, N, 4), dtype=np.float32)
     for b in range(B):
         for a in range(N):
-            l, t, r, br = anchor[0, a, :]
-            ay = (t + br) * 0.5
-            ax = (l + r) * 0.5
-            ah = br - t
-            aw = r - l
+            left, top, right, bottom = anchor[0, a, :]
+            ay = (top + bottom) * 0.5
+            ax = (left + right) * 0.5
+            ah = bottom - top
+            aw = right - left
             ex, ey, ew, eh = loc[b, a, :]
             ycenter = ey * vy * ah + ay
             xcenter = ex * vx * aw + ax
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py
index f053e36744..370b68769e 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,6 +75,138 @@ def test_all_class_non_max_suppression():
     _check(foo, bb.get()["foo"])
 
 
+def test_get_valid_counts():
+    @R.function
+    def foo(
+        data: R.Tensor((10, 5, 6), "float32"),
+    ) -> R.Tuple(
+        R.Tensor((10,), "int32"),
+        R.Tensor((10, 5, 6), "float32"),
+        R.Tensor((10, 5), "int32"),
+    ):
+        gv: R.Tuple(
+            R.Tensor((10,), "int32"),
+            R.Tensor((10, 5, 6), "float32"),
+            R.Tensor((10, 5), "int32"),
+        ) = R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1)
+        return gv
+
+    data = relax.Var("data", R.Tensor((10, 5, 6), "float32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [data]):
+        gv = bb.emit(
+            relax.op.vision.get_valid_counts(
+                data, score_threshold=0.5, id_index=0, score_index=1
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_non_max_suppression_return_indices():
+    @R.function
+    def foo(
+        data: R.Tensor((2, 5, 6), "float32"),
+        valid_count: R.Tensor((2,), "int32"),
+        indices: R.Tensor((2, 5), "int32"),
+    ) -> R.Tuple(R.Tensor((2, 5), "int32"), R.Tensor((2, 1), "int32")):
+        gv: R.Tuple(R.Tensor((2, 5), "int32"), R.Tensor((2, 1), "int32")) = (
+            R.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=3,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=True,
+                invalid_to_bottom=False,
+            )
+        )
+        return gv
+
+    data = relax.Var("data", R.Tensor((2, 5, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 5), "int32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [data, valid_count, indices]):
+        gv = bb.emit(
+            relax.op.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=3,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=True,
+                invalid_to_bottom=False,
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_non_max_suppression_return_data():
+    @R.function
+    def foo(
+        data: R.Tensor((2, 5, 6), "float32"),
+        valid_count: R.Tensor((2,), "int32"),
+        indices: R.Tensor((2, 5), "int32"),
+    ) -> R.Tensor((2, 5, 6), "float32"):
+        gv: R.Tensor((2, 5, 6), "float32") = R.vision.non_max_suppression(
+            data,
+            valid_count,
+            indices,
+            max_output_size=-1,
+            iou_threshold=0.5,
+            force_suppress=False,
+            top_k=-1,
+            coord_start=2,
+            score_index=1,
+            id_index=0,
+            return_indices=False,
+            invalid_to_bottom=True,
+        )
+        return gv
+
+    data = relax.Var("data", R.Tensor((2, 5, 6), "float32"))
+    valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+    indices = relax.Var("indices", R.Tensor((2, 5), "int32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [data, valid_count, indices]):
+        gv = bb.emit(
+            relax.op.vision.non_max_suppression(
+                data,
+                valid_count,
+                indices,
+                max_output_size=-1,
+                iou_threshold=0.5,
+                force_suppress=False,
+                top_k=-1,
+                coord_start=2,
+                score_index=1,
+                id_index=0,
+                return_indices=False,
+                invalid_to_bottom=True,
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_multibox_transform_loc():
     @R.function
     def foo(


Reply via email to