This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c00c66259a [Relax][ONNX] Support AllClassNMS Operator for ONNX 
Frontend (#18321)
c00c66259a is described below

commit c00c66259a8dd4cf197601c978c566ce2db9bc17
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Oct 1 16:09:59 2025 -0400

    [Relax][ONNX] Support AllClassNMS Operator for ONNX Frontend (#18321)
    
    Following #18175, this PR adds support for the AllClassNMS operator in the ONNX frontend.
---
 include/tvm/relax/attrs/vision.h                   |  54 +++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 179 +++++++-
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/op_attrs.py                    |   5 +
 .../tvm/{topi/cpp => relax/op}/vision/__init__.py  |   9 +-
 .../__init__.py => relax/op/vision/_ffi_api.py}    |   7 +-
 python/tvm/relax/op/vision/nms.py                  |  75 ++++
 .../tvm/relax/transform/legalize_ops/__init__.py   |   1 +
 python/tvm/relax/transform/legalize_ops/vision.py  | 120 +++++
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 python/tvm/topi/__init__.py                        |   1 +
 python/tvm/topi/cpp/vision/__init__.py             |   1 +
 python/tvm/topi/{cpp => }/vision/__init__.py       |   9 +-
 python/tvm/topi/vision/nms.py                      | 500 +++++++++++++++++++++
 python/tvm/topi/vision/nms_util.py                 | 473 +++++++++++++++++++
 src/relax/ir/emit_te.h                             |   4 +
 src/relax/op/vision/nms.cc                         | 114 +++++
 src/relax/op/vision/nms.h                          |  44 ++
 src/te/operation/create_primfunc.cc                |   5 +-
 tests/python/relax/test_frontend_onnx.py           | 426 ++++++++++++++++++
 tests/python/relax/test_op_vision.py               |  90 ++++
 .../relax/test_tvmscript_parser_op_vision.py       |  80 ++++
 22 files changed, 2179 insertions(+), 21 deletions(-)

diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
new file mode 100644
index 0000000000..2fd98533b5
--- /dev/null
+++ b/include/tvm/relax/attrs/vision.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/relax/attrs/vision.h
+ * \brief Auxiliary attributes for vision operators.
+ */
+#ifndef TVM_RELAX_ATTRS_VISION_H_
+#define TVM_RELAX_ATTRS_VISION_H_
+
+#include <tvm/ffi/string.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/type.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes used in AllClassNonMaximumSuppression operator */
+struct AllClassNonMaximumSuppressionAttrs
+    : public AttrsNodeReflAdapter<AllClassNonMaximumSuppressionAttrs> {
+  /*! \brief Output layout selector: "onnx" or "tensorflow" (see the reflection text below). */
+  ffi::String output_format;
+
+  /*! \brief Register output_format as a read-only (def_ro) reflected field with its description. */
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<AllClassNonMaximumSuppressionAttrs>().def_ro(
+        "output_format", &AllClassNonMaximumSuppressionAttrs::output_format,
+        "Output format, onnx or tensorflow. Returns outputs in a way that can 
be easily "
+        "consumed by each frontend.");
+  }
+  
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllClassNonMaximumSuppressionAttrs",
+                                    AllClassNonMaximumSuppressionAttrs, 
BaseAttrsNode);
+};  // struct AllClassNonMaximumSuppressionAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_VISION_H_
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7a4a65df6e..7432967c29 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3455,6 +3455,182 @@ class SequenceAt(OnnxOpConverter):
         return input_sequence[position]
 
 
+class NonMaxSuppression(OnnxOpConverter):
+    """Converts an onnx NonMaxSuppression node into an equivalent Relax 
expression."""
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        """
+        NonMaxSuppression performs non-maximum suppression (NMS) on all 
classes.
+
+        Inputs:
+        - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2]
+        - scores: (N, C) tensor of scores for each box and class
+        - max_output_boxes_per_class: maximum number of boxes to keep per class
+        - iou_threshold: IoU threshold for NMS
+        - score_threshold: score threshold for filtering
+
+        Outputs:
+        - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx]
+        """
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None
+        iou_threshold = inputs[3] if len(inputs) > 3 else None
+        score_threshold = inputs[4] if len(inputs) > 4 else None
+
+        center_point_box = attr.get("center_point_box", 0)
+
+        if max_output_boxes_per_class is not None and isinstance(
+            max_output_boxes_per_class, relax.Constant
+        ):
+            max_output_boxes_per_class = 
int(max_output_boxes_per_class.data.numpy())
+        elif max_output_boxes_per_class is not None and isinstance(
+            max_output_boxes_per_class, relax.Var
+        ):
+            var_name = max_output_boxes_per_class.name_hint
+            if var_name in params[1]:
+                _, param_value = params[1][var_name]
+                max_output_boxes_per_class = int(param_value.numpy().item())
+            else:
+                max_output_boxes_per_class = 100  # Default value
+        else:
+            max_output_boxes_per_class = 100  # Default value
+
+        if iou_threshold is not None and isinstance(iou_threshold, 
relax.Constant):
+            iou_threshold = float(iou_threshold.data.numpy())
+        else:
+            iou_threshold = 0.5  # Default value
+
+        if score_threshold is not None and isinstance(score_threshold, 
relax.Constant):
+            score_threshold = float(score_threshold.data.numpy())
+        elif score_threshold is not None and isinstance(score_threshold, 
relax.Var):
+            var_name = score_threshold.name_hint
+            if var_name in params[1]:
+                _, param_value = params[1][var_name]
+                score_threshold = float(param_value.numpy().item())
+            else:
+                score_threshold = 0.0  # Default value
+        else:
+            score_threshold = 0.0  # Default value
+
+        if center_point_box != 0:
+            split_result = relax.op.split(boxes, 4, axis=2)
+            xc = split_result[0]
+            yc = split_result[1]
+            w = split_result[2]
+            h = split_result[3]
+            half_w = w / relax.const(2.0, boxes.struct_info.dtype)
+            half_h = h / relax.const(2.0, boxes.struct_info.dtype)
+            x1 = xc - half_w
+            x2 = xc + half_w
+            y1 = yc - half_h
+            y2 = yc + half_h
+            boxes = relax.op.concat([y1, x1, y2, x2], axis=2)
+
+        nms_out = bb.normalize(
+            relax.op.vision.all_class_non_max_suppression(
+                boxes,
+                scores,
+                relax.const(max_output_boxes_per_class, dtype="int64"),
+                relax.const(iou_threshold, dtype="float32"),
+                relax.const(score_threshold, dtype="float32"),
+                output_format="onnx",
+            )
+        )
+
+        selected_indices = bb.emit(relax.TupleGetItem(nms_out, 0))
+
+        return selected_indices
+
+
+class AllClassNMS(OnnxOpConverter):
+    """Converts an onnx AllClassNMS node into an equivalent Relax 
expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        """
+        AllClassNMS performs non-maximum suppression (NMS) on all classes.
+
+        Inputs:
+        - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2]
+        - scores: (N, C) tensor of scores for each box and class
+        - max_output_boxes_per_class: maximum number of boxes to keep per class
+        - iou_threshold: IoU threshold for NMS
+        - score_threshold: score threshold for filtering
+
+        Outputs:
+        - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx]
+        """
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None
+        iou_threshold = inputs[3] if len(inputs) > 3 else None
+        score_threshold = inputs[4] if len(inputs) > 4 else None
+
+        center_point_box = attr.get("center_point_box", 0)
+
+        if max_output_boxes_per_class is not None and isinstance(
+            max_output_boxes_per_class, relax.Constant
+        ):
+            max_output_boxes_per_class = 
int(max_output_boxes_per_class.data.numpy())
+        elif max_output_boxes_per_class is not None and isinstance(
+            max_output_boxes_per_class, relax.Var
+        ):
+            var_name = max_output_boxes_per_class.name_hint
+            if var_name in params[1]:
+                _, param_value = params[1][var_name]
+                max_output_boxes_per_class = int(param_value.numpy().item())
+            else:
+                max_output_boxes_per_class = 100  # Default value
+        else:
+            max_output_boxes_per_class = 100  # Default value
+
+        if iou_threshold is not None and isinstance(iou_threshold, 
relax.Constant):
+            iou_threshold = float(iou_threshold.data.numpy())
+        else:
+            iou_threshold = 0.5  # Default value
+
+        if score_threshold is not None and isinstance(score_threshold, 
relax.Constant):
+            score_threshold = float(score_threshold.data.numpy())
+        elif score_threshold is not None and isinstance(score_threshold, 
relax.Var):
+            var_name = score_threshold.name_hint
+            if var_name in params[1]:
+                _, param_value = params[1][var_name]
+                score_threshold = float(param_value.numpy().item())
+            else:
+                score_threshold = 0.0  # Default value
+        else:
+            score_threshold = 0.0  # Default value
+
+        if center_point_box != 0:
+            split_result = relax.op.split(boxes, 4, axis=2)
+            xc = split_result[0]
+            yc = split_result[1]
+            w = split_result[2]
+            h = split_result[3]
+            half_w = w / relax.const(2.0, boxes.struct_info.dtype)
+            half_h = h / relax.const(2.0, boxes.struct_info.dtype)
+            x1 = xc - half_w
+            x2 = xc + half_w
+            y1 = yc - half_h
+            y2 = yc + half_h
+            boxes = relax.op.concat([y1, x1, y2, x2], axis=2)
+
+        nms_out = bb.normalize(
+            relax.op.vision.all_class_non_max_suppression(
+                boxes,
+                scores,
+                relax.const(max_output_boxes_per_class, dtype="int64"),
+                relax.const(iou_threshold, dtype="float32"),
+                relax.const(score_threshold, dtype="float32"),
+                output_format="onnx",
+            )
+        )
+
+        return nms_out
+
+
 def _get_convert_map():
     return {
         # defs/experimental
@@ -3605,7 +3781,8 @@ def _get_convert_map():
         # "LRN": LRN,
         # "MaxRoiPool": MaxRoiPool,
         # "RoiAlign": RoiAlign,
-        # "NonMaxSuppression": NonMaxSuppression,
+        "NonMaxSuppression": NonMaxSuppression,
+        "AllClassNMS": AllClassNMS,
         # "GridSample": GridSample,
         "Upsample": Upsample,
         # others
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 6ea8305eca..19096decd9 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -155,6 +155,7 @@ from .unary import (
     tanh,
     trunc,
 )
+from .vision import all_class_non_max_suppression
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 4062aae0c7..229a789a45 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -239,6 +239,11 @@ class AttentionAttrs(Attrs):
     """Attributes used in attention operator"""
 
 
+@tvm_ffi.register_object("relax.attrs.AllClassNonMaximumSuppressionAttrs")
+class AllClassNonMaximumSuppressionAttrs(Attrs):
+    """Attributes for vision.all_class_non_max_suppression.
+
+    Fields are reflected from the C++ node registered as
+    "relax.attrs.AllClassNonMaximumSuppressionAttrs":
+
+    output_format : str
+        Output format, "onnx" or "tensorflow".
+    """
+
 @tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
 class Conv1DAttrs(Attrs):
     """Attributes for nn.conv1d"""
diff --git a/python/tvm/topi/cpp/vision/__init__.py 
b/python/tvm/relax/op/vision/__init__.py
similarity index 84%
copy from python/tvm/topi/cpp/vision/__init__.py
copy to python/tvm/relax/op/vision/__init__.py
index 8acbb38610..be45458d36 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -14,10 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-"""FFI for vision TOPI ops and schedules"""
-import tvm_ffi
-
-from . import yolo
-
-tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
+"""Relax vision operators (non-maximum suppression)."""
+from .nms import *
diff --git a/python/tvm/topi/cpp/vision/__init__.py 
b/python/tvm/relax/op/vision/_ffi_api.py
similarity index 86%
copy from python/tvm/topi/cpp/vision/__init__.py
copy to python/tvm/relax/op/vision/_ffi_api.py
index 8acbb38610..8af761dc5a 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/relax/op/vision/_ffi_api.py
@@ -14,10 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-"""FFI for vision TOPI ops and schedules"""
+"""Constructor APIs"""
 import tvm_ffi
 
-from . import yolo
-
-tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
+tvm_ffi.init_ffi_api("relax.op.vision", __name__)
diff --git a/python/tvm/relax/op/vision/nms.py 
b/python/tvm/relax/op/vision/nms.py
new file mode 100644
index 0000000000..3714b00b01
--- /dev/null
+++ b/python/tvm/relax/op/vision/nms.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Non-maximum suppression operator"""
+# from tvm import relax  # Unused import
+from . import _ffi_api
+
+
+def all_class_non_max_suppression(
+    boxes,
+    scores,
+    max_output_boxes_per_class,
+    iou_threshold,
+    score_threshold,
+    output_format="onnx",
+):
+    """Non-maximum suppression operator for object detection, corresponding to 
ONNX
+    NonMaxSuppression and TensorFlow combined_non_max_suppression.
+    NMS is performed for each class separately.
+
+    Parameters
+    ----------
+    boxes : relax.Expr
+        3-D tensor with shape (batch_size, num_boxes, 4)
+    scores: relax.Expr
+        3-D tensor with shape (batch_size, num_classes, num_boxes)
+    max_output_boxes_per_class : relax.Expr
+        The maxinum number of output selected boxes per class
+    iou_threshold : relax.Expr
+        IoU test threshold
+    score_threshold : relax.Expr
+        Score threshold to filter out low score boxes early
+    output_format : str, optional
+        "onnx" or "tensorflow", see below.
+
+    Returns
+    -------
+    out : relax.Expr
+        If `output_format` is "onnx", the output is two tensors. The first is 
`indices` of size
+        `(batch_size * num_class* num_boxes , 3)` and the second is a scalar 
tensor
+        `num_total_detection` of shape `(1,)` representing the total number of 
selected
+        boxes. The three values in `indices` encode batch, class, and box 
indices.
+        Rows of `indices` are ordered such that selected boxes from batch 0, 
class 0 come
+        first, in descending of scores, followed by boxes from batch 0, class 
1 etc. Out of
+        `batch_size * num_class* num_boxes` rows of indices, only the first 
`num_total_detection`
+        rows are valid.
+
+        TODO: Implement true dynamic output shapes to match ONNX Runtime 
behavior exactly.
+        This would eliminate the need for manual trimming and improve memory 
efficiency.
+        If `output_format` is "tensorflow", the output is three tensors, the 
first
+        is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the 
second is `scores` of
+        size `(batch_size, num_class * num_boxes)`, and the third is 
`num_total_detection` of size
+        `(batch_size,)` representing the total number of selected boxes per 
batch. The two values
+        in `indices` encode class and box indices. Of num_class * num_boxes 
boxes in `indices` at
+        batch b, only the first `num_total_detection[b]` entries are valid. 
The second axis of
+        `indices` and `scores` are sorted within each class by box scores, but 
not across classes.
+        So the box indices and scores for the class 0 come first in a sorted 
order, followed by
+        the class 1 etc.
+    """
+    return _ffi_api.all_class_non_max_suppression(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, 
score_threshold, output_format
+    )
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py 
b/python/tvm/relax/transform/legalize_ops/__init__.py
index b4aba0291f..5614d02296 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -31,3 +31,4 @@ from . import qdq
 from . import search
 from . import statistical
 from . import unary
+from . import vision
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py 
b/python/tvm/relax/transform/legalize_ops/vision.py
new file mode 100644
index 0000000000..f910f62cec
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Default legalization function for vision network related operators."""
+from tvm import topi, te
+from tvm import relax
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from .common import register_legalize
+
+
+def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, 
iou_threshold, score_threshold):
+    """Create a proper NMS implementation that follows the correct algorithm"""
+    scores_shape = list(scores.shape)
+    if len(scores_shape) == 3:
+        batch, num_classes, _ = scores_shape
+    elif len(scores_shape) == 2:
+        num_classes, _ = scores_shape
+        batch = 1
+    else:
+        raise ValueError(f"Unexpected scores shape: {scores_shape}")
+
+    if hasattr(max_output_boxes_per_class, "data"):
+        max_boxes = int(max_output_boxes_per_class.data.numpy())
+    else:
+        max_boxes = 3  # Default value
+
+    expected_detections = batch * num_classes * max_boxes
+
+    selected_indices_full, _ = topi.vision.all_class_non_max_suppression(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, 
score_threshold, "onnx"
+    )
+
+    def slice_to_onnx_shape(data, expected_size):
+        def compute_element(i, j):
+            return tvm.tir.if_then_else(i < expected_size, data[i, j], 
tvm.tir.Cast("int64", 0))
+
+        return te.compute((expected_size, 3), compute_element, 
name="sliced_indices")
+
+    sliced_indices = slice_to_onnx_shape(selected_indices_full, 
expected_detections)
+
+    actual_detections = te.compute(
+        (1,), lambda i: tvm.tir.Cast("int64", expected_detections), 
name="actual_detections"
+    )
+
+    return [sliced_indices, actual_detections]
+
+
+@register_legalize("relax.vision.all_class_non_max_suppression")
+def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> 
Expr:
+    """Legalize all_class_non_max_suppression with fixed shape output.
+
+    Note: This implementation outputs fixed-size tensors with trailing garbage 
data.
+    Only the first `num_total_detection` rows contain valid data. Users should 
use
+    the `valid_count` tensor to determine how many rows are actually valid.
+
+    For complete ONNX compatibility, users can post-process the output:
+    ```python
+    selected_indices, valid_count = nms_output
+    actual_count = int(valid_count.numpy()[0])
+    valid_indices = selected_indices.numpy()[:actual_count, :]
+    ```
+    """
+    boxes = call.args[0]
+    scores = call.args[1]
+    max_output_boxes_per_class = call.args[2]
+    iou_threshold = call.args[3]
+    score_threshold = call.args[4]
+    output_format = call.attrs.output_format
+
+    scores_shape = scores.struct_info.shape
+    if len(scores_shape) == 3:
+        _, _, num_boxes = scores_shape
+    elif len(scores_shape) == 2:
+        _, num_boxes = scores_shape
+    else:
+        raise ValueError(f"Unexpected scores shape: {scores_shape}")
+
+    if isinstance(max_output_boxes_per_class, relax.Constant):
+        max_boxes_val = int(max_output_boxes_per_class.data.numpy())
+    else:
+        max_boxes_val = int(num_boxes)
+
+    # Get NMS result with fixed shape from TOPI
+    nms_result = block_builder.call_te(
+        topi.vision.all_class_non_max_suppression,
+        boxes,
+        scores,
+        max_boxes_val,
+        iou_threshold,
+        score_threshold,
+        output_format,
+    )
+
+    # TODO: Implement dynamic output trimming for better memory efficiency
+    # Current approach returns fixed-size output with trailing garbage data
+    # Future improvements could include:
+    # 1. Dynamic strided_slice based on num_total_detections
+    # 2. Custom Relax operator with true dynamic shapes
+    # 3. VM builtin functions for runtime shape adjustment
+    # 4. Symbolic shape inference in Relax IR
+    #
+    # For now, users should trim manually:
+    # actual_count = int(num_total_detections.numpy()[0])
+    # valid_indices = selected_indices.numpy()[:actual_count, :]
+
+    return nms_result
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 3fa735197a..f221a13089 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -188,6 +188,7 @@ from tvm.relax.op import (
     wrap_param,
     zeros,
     zeros_like,
+    vision,
 )
 from tvm.relax.op.builtin import stop_lift_params
 from tvm.relax.struct_info import StructInfo
@@ -950,4 +951,5 @@ __all__ = [
     "nn",
     "ccl",
     "erf",
+    "vision",
 ]
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 9503aea0cd..c73e8bf54c 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -50,6 +50,7 @@ from .signal import *
 from . import nn
 from . import utils
 from . import image
+from . import vision
 from . import gpu
 
 # error reporting
diff --git a/python/tvm/topi/cpp/vision/__init__.py 
b/python/tvm/topi/cpp/vision/__init__.py
index 8acbb38610..467ce70fbd 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/topi/cpp/vision/__init__.py
@@ -19,5 +19,6 @@
 import tvm_ffi
 
 from . import yolo
+from ...vision import nms
 
 tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
diff --git a/python/tvm/topi/cpp/vision/__init__.py 
b/python/tvm/topi/vision/__init__.py
similarity index 84%
copy from python/tvm/topi/cpp/vision/__init__.py
copy to python/tvm/topi/vision/__init__.py
index 8acbb38610..f12758bb9c 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -14,10 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-"""FFI for vision TOPI ops and schedules"""
-import tvm_ffi
-
-from . import yolo
-
-tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
+"""Vision operators."""
+from .nms import *
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
new file mode 100644
index 0000000000..f4aae45ef9
--- /dev/null
+++ b/python/tvm/topi/vision/nms.py
@@ -0,0 +1,500 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-error, invalid-name, no-member, too-many-locals, 
too-many-arguments, undefined-variable, too-many-nested-blocks, 
too-many-branches, too-many-statements, too-many-function-args
+"""Non-maximum suppression operator"""
+import tvm
+from tvm import te
+
+from tvm.tir import if_then_else
+
+from ..sort import argsort
+from ..math import cast
+from ..transform import reshape, gather
+from .. import reduction
+from ..scan import cumsum
+from .nms_util import (
+    binary_search,
+    collect_selected_indices,
+    collect_selected_indices_and_scores,
+    run_all_class_nms,
+)
+
+
def get_valid_counts(
    data, score_threshold=0, id_index=0, score_index=1
):  # pylint: disable=unused-argument
    """Return a pass-through "valid count" result for ``data``.

    NOTE: this is currently a placeholder. It does not apply ``score_threshold``,
    ``id_index`` or ``score_index`` and it does not move valid boxes to the top:
    every anchor is reported as valid, ``data`` is returned unchanged and the index
    map is the identity.

    Parameters
    ----------
    data : tvm.te.Tensor
        Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
        or [batch_size, num_anchors, 5].
    score_threshold : optional, float
        Lower limit of score for valid bounding boxes (currently ignored).
    id_index : optional, int
        Index of the class categories, -1 to disable (currently ignored).
    score_index : optional, int
        Index of the scores/confidence of boxes (currently ignored).

    Returns
    -------
    valid_count : tvm.te.Tensor
        1-D tensor of shape [batch_size] whose every entry is num_anchors.
    out_tensor : tvm.te.Tensor
        ``data`` itself.
    out_indices : tvm.te.Tensor
        2-D tensor of shape [batch_size, num_anchors] with ``out_indices[i, j] == j``.
    """
    batch_size, num_anchors = data.shape[0], data.shape[1]
    valid_count = te.compute((batch_size,), lambda i: num_anchors, name="valid_count")
    out_indices = te.compute((batch_size, num_anchors), lambda i, j: j, name="out_indices")
    return valid_count, data, out_indices
+
+
def _nms_loop(
    ib,
    batch_size,
    top_k,
    iou_threshold,
    max_output_size,
    valid_count,
    on_new_valid_box_func,
    on_new_invalidated_box_func,
    needs_bbox_check_func,
    calc_overlap_func,
    out_scores,
    num_valid_boxes,
    score_threshold=None,
):
    """Emit a greedy NMS loop over each of ``batch_size`` rows into ``ib``.

    For row ``i``, boxes ``j`` in ``[0, nkeep)`` are visited in index order, where
    ``nkeep`` is ``top_k`` when ``0 < top_k < valid_count[i]`` and ``valid_count[i]``
    otherwise. A visited box whose score is still ``> -1.0`` is selected while fewer
    than ``max_output_size`` boxes have been selected (and, when ``score_threshold``
    is given, only if its score exceeds ``score_threshold[()]``). Selecting box ``j``
    invalidates every later box ``k < nkeep`` whose overlap with ``j`` is
    ``>= iou_threshold`` by setting ``out_scores[i, k] = -1.0``. The number of
    selected boxes is written to ``num_valid_boxes[i]``.

    Parameters
    ----------
    ib : tvm.tir.ir_builder.IRBuilder
        Builder the loop nest is emitted into.
    batch_size : int or PrimExpr
        Number of rows processed (presumably batch * num_classes for all-class
        NMS -- TODO confirm against the caller).
    top_k : int or PrimExpr
        Per-row cap on the boxes considered; values ``<= 0`` mean no cap.
    iou_threshold : PrimExpr
        Overlap at or above which a later box is suppressed. If it is not ``> 0``
        the row selects no boxes at all.
    max_output_size : int or PrimExpr
        Maximum number of boxes selected per row. It is used as-is, so a value
        ``<= 0`` selects no boxes.
    valid_count : buffer
        Per-row number of candidate boxes.
    on_new_valid_box_func : callable
        Called as ``(ib, 0, num_selected_so_far, i, j)`` when box ``j`` is selected.
    on_new_invalidated_box_func : callable
        Called as ``(i, k)`` after box ``k`` is suppressed.
    needs_bbox_check_func : callable
        ``(i, j, k) -> PrimExpr`` predicate that must hold before the overlap is tested.
    calc_overlap_func : callable
        ``(i, j, k) -> PrimExpr`` returning the overlap of boxes ``j`` and ``k``.
    out_scores : buffer
        Per-row scores; modified in place (suppressed boxes become -1.0).
        Presumably sorted in descending order so the greedy pass keeps the
        highest-scoring boxes -- verify against the caller.
    num_valid_boxes : buffer
        Output: number of selected boxes per row.
    score_threshold : buffer, optional
        When given, a box is selected only if its score is ``> score_threshold[()]``.

    Returns
    -------
    stmt : tvm.tir.Stmt
        The statement built by ``ib``.
    """

    def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local):
        """Select box ``j`` of row ``i`` and suppress later boxes overlapping it."""
        on_new_valid_box_func(ib, 0, num_valid_boxes_local[0], i, j)
        num_valid_boxes_local[0] += 1

        num_boxes_to_check = nkeep - (j + 1)

        # Candidates k = j + 1 .. nkeep - 1; each iteration touches only its own k.
        with ib.for_range(0, num_boxes_to_check, name="_k", kind="parallel") as _k:
            k = j + 1 + _k

            with ib.if_scope(
                tvm.tir.all(
                    k < nkeep,
                    # is the box k still valid?
                    # NOTE(review): validity is tested with `> 0` here but with
                    # `> -1.0` in the outer loop, so a box whose score is in
                    # (-1, 0] can be selected yet is never suppressed -- confirm intent.
                    out_scores[i, k] > 0,
                    needs_bbox_check_func(i, j, k),
                )
            ):
                iou = calc_overlap_func(i, j, k)

                with ib.if_scope(iou >= iou_threshold):
                    # Mark box k as suppressed.
                    out_scores[i, k] = -1.0
                    on_new_invalidated_box_func(i, k)

    with ib.for_range(0, batch_size, name="i") as i:
        # Number of leading boxes considered in this row: top_k if it is a
        # positive value smaller than valid_count[i], else valid_count[i].
        nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i])
        # max_output_size is intentionally not replaced by nkeep when it is <= 0;
        # such a value therefore selects no boxes.

        with ib.if_scope(tvm.tir.all(iou_threshold > te.const(0), valid_count[i] > te.const(0))):
            num_valid_boxes_local = ib.allocate(
                "int32", (1,), name="num_valid_boxes_local", scope="local"
            )
            num_valid_boxes_local[0] = 0

            # Visit every box up to nkeep; the condition below stops selecting
            # once max_output_size boxes have been kept (no early exit).
            with ib.for_range(0, nkeep, name="j") as j:
                with ib.if_scope(
                    tvm.tir.all(
                        out_scores[i, j] > -1.0,  # box is still valid
                        num_valid_boxes_local[0] < max_output_size,  # haven't reached max limit
                    )
                ):
                    if score_threshold is not None:
                        with ib.if_scope(out_scores[i, j] > score_threshold[()]):
                            nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local)
                    else:
                        nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local)

            num_valid_boxes[i] = num_valid_boxes_local[0]

        # Rows with a non-positive IoU threshold or no candidates select nothing.
        with ib.else_scope():
            num_valid_boxes[i] = 0

    return ib.get()
+
+
def _get_valid_box_count(scores, score_threshold):
    """Count, for every row of ``scores``, the entries strictly above ``score_threshold``.

    Parameters
    ----------
    scores : tvm.te.Tensor
        2-D tensor of shape (batch_size * num_classes, num_boxes). Each row must be
        sorted in descending order because the count is found by binary search.
    score_threshold : tvm.te.Tensor or float
        A tensor of shape () or (N,) with N > 0 (only element 0 is read; any other
        shape falls back to a threshold of 0), or a Python number.

    Returns
    -------
    valid_count : tvm.te.Tensor
        int32 tensor of shape (batch_size * num_classes,).
    """
    batch_classes, num_boxes = scores.shape
    # Constants synthesized here use the scores' dtype so binary_search compares
    # like with like; a hard-coded float32 constant mismatches e.g. float16 scores.
    score_dtype = scores.dtype
    has_tensor_threshold = hasattr(score_threshold, "shape")

    def _threshold_expr(thresh_buf):
        """Return the scalar threshold expression shared by all binary searches."""
        if not has_tensor_threshold:
            return tvm.tir.const(float(score_threshold), score_dtype)
        ndim = len(score_threshold.shape)
        if ndim == 0:
            return thresh_buf[()]
        if ndim == 1 and score_threshold.shape[0] > 0:
            return thresh_buf[0]
        # Unsupported threshold shape: keep the historical fallback of 0.
        return tvm.tir.const(0.0, score_dtype)

    def searchsorted_ir(scores_in, thresh_buf, valid_count):
        """Emit one binary search per row; ``thresh_buf`` is None for a scalar threshold."""
        ib = tvm.tir.ir_builder.create()
        scores_ptr = ib.buffer_ptr(scores_in)
        valid_count_ptr = ib.buffer_ptr(valid_count)
        thresh = _threshold_expr(thresh_buf)

        with ib.for_range(0, batch_classes, name="i", kind="parallel") as i:
            binary_search(ib, i, num_boxes, scores_ptr, thresh, valid_count_ptr)

        return ib.get()

    scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8)
    searchsorted_buf = tvm.tir.decl_buffer(
        (batch_classes,), "int32", "searchsorted", data_alignment=8
    )

    inputs = [scores]
    in_buffers = [scores_buf]
    if has_tensor_threshold:
        # A tensor threshold must be an extern input so its value can be read at runtime.
        inputs.append(score_threshold)
        in_buffers.append(
            tvm.tir.decl_buffer(
                score_threshold.shape, score_threshold.dtype, "score_thresh_buf", data_alignment=8
            )
        )

    return te.extern(
        [(batch_classes,)],
        inputs,
        lambda ins, outs: searchsorted_ir(
            ins[0], ins[1] if has_tensor_threshold else None, outs[0]
        ),
        dtype=["int32"],
        in_buffers=in_buffers,
        out_buffers=[searchsorted_buf],
        name="searchsorted",
        tag="searchsorted",
    )
+
+
def _collect_selected_indices_ir(
    num_class, selected_indices, num_detections, row_offsets, out, max_output_boxes_per_class=None
):
    """Build IR that scatters per-(batch, class) NMS results into ONNX-style rows.

    Row ``row_offsets[i] + j`` of ``out`` receives ``[batch_id, class_id, box_index]``
    for the j-th kept box of flattened (batch, class) index ``i``. All other rows are
    zero.

    Parameters
    ----------
    num_class : int
        Number of classes; used to split ``i`` into ``batch_id`` and ``class_id``.
    selected_indices : tvm.tir.Buffer
        2-D buffer (batch_size * num_classes, num_boxes) of kept box indices.
    num_detections : tvm.tir.Buffer
        1-D buffer holding the number of kept boxes per (batch, class) row.
    row_offsets : tvm.tir.Buffer
        1-D buffer, the exclusive prefix sum of ``num_detections``.
    out : tvm.tir.Buffer
        2-D int64 output buffer with 3 columns.
    max_output_boxes_per_class : int, tvm.te.Tensor or None
        Optional cap on the number of rows written per (batch, class).

    Returns
    -------
    stmt : tvm.tir.Stmt
    """
    batch_classes, _ = selected_indices.shape
    # Use the real allocated row count of ``out``: callers size it from an explicit
    # output_shape, from max_output_boxes_per_class, or from a heuristic estimate,
    # so re-deriving it here could overrun or under-fill the buffer.
    out_rows = out.shape[0]

    ib = tvm.tir.ir_builder.create()

    selected_indices = ib.buffer_ptr(selected_indices)
    num_detections = ib.buffer_ptr(num_detections)
    row_offsets = ib.buffer_ptr(row_offsets)
    out = ib.buffer_ptr(out)

    # Zero the whole output so rows past the valid detections are deterministic.
    with ib.for_range(0, out_rows, name="init_i") as init_i:
        with ib.for_range(0, 3, name="init_j") as init_j:  # 3 columns
            out[init_i, init_j] = cast(0, "int64")

    with ib.for_range(0, batch_classes, name="i", kind="parallel") as i:
        i = cast(i, "int64")
        batch_id = i // num_class
        class_id = i % num_class

        if isinstance(max_output_boxes_per_class, int):
            limit = tvm.tir.min(
                num_detections[i], tvm.tir.IntImm("int32", max_output_boxes_per_class)
            )
        elif isinstance(max_output_boxes_per_class, te.Tensor):
            # NOTE(review): this tensor is read inside the extern body but is not one
            # of the extern's inputs -- verify it is bound when lowering.
            if len(max_output_boxes_per_class.shape) == 0:
                max_boxes_val = max_output_boxes_per_class[()]
            else:
                max_boxes_val = max_output_boxes_per_class[0]
            limit = tvm.tir.min(num_detections[i], max_boxes_val)
        else:
            limit = num_detections[i]

        with ib.for_range(0, limit, name="j") as j:
            row = row_offsets[i] + j
            # Drop rows that do not fit when the output was sized by an estimate.
            with ib.if_scope(row < out_rows):
                out[row, 0] = batch_id
                out[row, 1] = class_id
                out[row, 2] = cast(selected_indices[i, j], "int64")

    return ib.get()
+
+
def _collect_selected_indices_and_scores_ir(
    selected_indices,
    selected_scores,
    num_detections,
    row_offsets,
    num_total_detections,
    collected_indices,
    collected_scores,
):
    """Build IR packing per-class NMS results into TensorFlow-style per-batch outputs.

    For batch ``b`` and class ``c``, the first ``num_detections[b, c]`` kept boxes are
    written contiguously from ``row_offsets[b, c]`` as ``[c, box_index]`` together with
    their scores. The class's remaining slots are written as zeros at positions past
    ``num_total_detections[b]``, after the padding of all earlier classes.
    """
    batch_size, num_class = row_offsets.shape
    num_boxes = selected_indices.shape[1]

    ib = tvm.tir.ir_builder.create()

    sel_idx = ib.buffer_ptr(selected_indices)
    sel_scores = ib.buffer_ptr(selected_scores)
    det_count = ib.buffer_ptr(num_detections)
    offsets = ib.buffer_ptr(row_offsets)
    total_count = ib.buffer_ptr(num_total_detections)
    dst_idx = ib.buffer_ptr(collected_indices)
    dst_scores = ib.buffer_ptr(collected_scores)
    zero = cast(0, "int64")

    with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as row:
        row = cast(row, "int64")
        bid = row // num_class
        cid = row % num_class
        kept = det_count[bid, cid]
        start = offsets[bid, cid]

        with ib.for_range(0, num_boxes, name="j") as j:
            with ib.if_scope(j < kept):
                # Valid detection: packed right after the previous classes' detections.
                pos = start + j
                dst_idx[bid, pos, 0] = cid
                dst_idx[bid, pos, 1] = cast(sel_idx[row, j], "int64")
                dst_scores[bid, pos] = sel_scores[row, j]
            with ib.else_scope():
                # Padding slot: after every valid entry of the batch and the
                # padding already emitted for earlier classes.
                pos = total_count[bid] + cid * num_boxes - start + j - kept
                dst_idx[bid, pos, 0] = zero
                dst_idx[bid, pos, 1] = zero
                dst_scores[bid, pos] = 0.0

    return ib.get()
+
+
def _clamped_total_detections(num_detections, max_output_boxes_per_class):
    """Return ``sum_i min(num_detections[i], limit)`` as an int64 scalar tensor.

    ``limit`` comes from ``max_output_boxes_per_class`` when it is a Python int, a
    ``tvm.tir.IntImm``, or a ``te.Tensor`` of shape () or (1,); for any other value the
    counts are summed without clamping.
    """
    limit = None
    if isinstance(max_output_boxes_per_class, int):
        limit = tvm.tir.IntImm("int32", int(max_output_boxes_per_class))
    elif isinstance(max_output_boxes_per_class, tvm.tir.IntImm):
        limit = tvm.tir.Cast("int32", max_output_boxes_per_class)
    elif isinstance(max_output_boxes_per_class, te.Tensor):
        limit_shape = max_output_boxes_per_class.shape
        if len(limit_shape) == 0:
            # Load the 0-d tensor's value; casting the tensor itself would yield
            # another tensor, which is not a valid te.compute body.
            limit = cast(max_output_boxes_per_class(), "int32")
        elif len(limit_shape) == 1 and limit_shape[0] == 1:
            limit = cast(max_output_boxes_per_class[0], "int32")

    if limit is None:
        return reduction.sum(cast(num_detections, "int64"), axis=0)

    clamped = te.compute(
        num_detections.shape,
        lambda i: tvm.tir.min(num_detections[i], limit),
        name="clamped_num",
    )
    return reduction.sum(cast(clamped, "int64"), axis=0)


def all_class_non_max_suppression(
    boxes,
    scores,
    max_output_boxes_per_class,
    iou_threshold,
    score_threshold,
    output_format="onnx",
    output_shape=None,
):
    """Non-maximum suppression operator for object detection, corresponding to ONNX
    NonMaxSuppression and TensorFlow combined_non_max_suppression.

    NMS is performed for each class separately.

    Parameters
    ----------
    boxes : tvm.te.Tensor
        3-D tensor with shape (batch_size, num_boxes, 4)
    scores : tvm.te.Tensor
        3-D tensor with shape (batch_size, num_classes, num_boxes)
    max_output_boxes_per_class : int or tvm.te.Tensor
        The maximum number of output selected boxes per class
    iou_threshold : float or tvm.te.Tensor
        IoU test threshold
    score_threshold : float or tvm.te.Tensor
        Score threshold to filter out low score boxes early. A non-tensor value is
        wrapped in a 0-d tensor.
    output_format : str, optional
        "onnx" or "tensorflow", see below. Any other value raises ``ValueError``.
    output_shape : tuple, optional
        Only used by the "onnx" format: explicit shape of the returned ``indices``.
        When omitted, ``indices`` has ``batch_size * num_classes *
        max_output_boxes_per_class`` rows if that limit is a Python int, and a
        heuristic estimate of rows otherwise.

    Returns
    -------
    out : list of tvm.te.Tensor
        If `output_format` is "onnx", the output is two tensors. The first is `indices`
        with 3 int64 columns and the second is the int64 tensor `num_total_detection` of
        shape `(1,)` holding the total number of selected boxes. The three values in
        `indices` encode batch, class, and box indices. Rows of `indices` are ordered such
        that selected boxes from batch 0, class 0 come first, in descending order of
        scores, followed by boxes from batch 0, class 1 etc. Only the first
        `num_total_detection` rows are valid; the remaining rows are padding.

        .. note::
            The output tensor has a fixed size, but only the first `num_total_detection`
            rows contain valid data. When comparing with ONNX Runtime or other
            implementations that output dynamic shapes, compare only those rows:

            .. code-block:: python

                selected_indices, valid_count = nms_output
                actual_count = int(valid_count.numpy()[0])
                valid_indices = selected_indices.numpy()[:actual_count, :]

        If `output_format` is "tensorflow", the output is three tensors, the first
        is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is
        `scores` of size `(batch_size, num_class * num_boxes)`, and the third is
        `num_total_detection` of size `(batch_size,)` representing the total number of
        selected boxes per batch. The two values in `indices` encode class and box
        indices. Of num_class * num_boxes boxes in `indices` at batch b, only the first
        `num_total_detection[b]` entries are valid. The second axis of `indices` and
        `scores` are sorted within each class by box scores, but not across classes.
        So the box indices and scores for the class 0 come first in a sorted order,
        followed by the class 1 etc.
    """
    if output_format not in ("onnx", "tensorflow"):
        raise ValueError(
            f"Unsupported output_format {output_format!r}; expected 'onnx' or 'tensorflow'"
        )

    batch, num_class, num_boxes = scores.shape
    # From here on each (batch, class) pair is one row of scores.
    scores = reshape(scores, (batch * num_class, num_boxes))

    sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32")
    sorted_scores = gather(scores, 1, sorted_indices)

    if isinstance(score_threshold, te.Tensor):
        score_threshold_tensor = score_threshold
    else:
        score_threshold_tensor = te.compute((), lambda: score_threshold, name="score_threshold")

    valid_count = _get_valid_box_count(sorted_scores, score_threshold_tensor)

    selected_indices, selected_scores, num_detections = run_all_class_nms(
        boxes,
        sorted_scores,
        sorted_indices,
        valid_count,
        max_output_boxes_per_class,
        iou_threshold,
        _nms_loop,
        return_scores=(output_format == "tensorflow"),
        score_threshold=score_threshold_tensor,
    )

    if output_format == "onnx":
        row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
        num_total_detections = reshape(
            _clamped_total_detections(num_detections, max_output_boxes_per_class), (1,)
        )
        # collect_selected_indices honours output_shape when given and otherwise
        # sizes the output from max_output_boxes_per_class (or an estimate).
        selected_indices = collect_selected_indices(
            num_class,
            selected_indices,
            num_detections,
            row_offsets,
            _collect_selected_indices_ir,
            max_output_boxes_per_class=max_output_boxes_per_class,
            output_shape=output_shape,
            num_total_detections=num_total_detections,
        )
        return [selected_indices, num_total_detections]

    num_detections_per_batch = reshape(num_detections, (batch, num_class))
    row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1)
    num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1)

    selected_indices, selected_scores = collect_selected_indices_and_scores(
        selected_indices,
        selected_scores,
        num_detections_per_batch,
        row_offsets,
        num_total_detections,
        _collect_selected_indices_and_scores_ir,
    )

    return [selected_indices, selected_scores, num_total_detections]
diff --git a/python/tvm/topi/vision/nms_util.py 
b/python/tvm/topi/vision/nms_util.py
new file mode 100644
index 0000000000..1633c923e1
--- /dev/null
+++ b/python/tvm/topi/vision/nms_util.py
@@ -0,0 +1,473 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Common utilities used in Non-maximum suppression operators"""
+import tvm
+from tvm import te
+
+
def _get_boundaries(output, box_idx):
    """Return (left, top, right, bottom) of the box whose 4 coordinates start at ``box_idx``.

    The two corners may be stored in either order; min/max normalizes them.
    """
    x1, y1 = output[box_idx], output[box_idx + 1]
    x2, y2 = output[box_idx + 2], output[box_idx + 3]
    left, right = tvm.te.min(x1, x2), tvm.te.max(x1, x2)
    top, bottom = tvm.te.min(y1, y2), tvm.te.max(y1, y2)
    return left, top, right, bottom
+
+
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
    """Calculate overlap of two boxes.

    Returns the intersection-over-union of the boxes starting at ``box_a_idx`` and
    ``box_b_idx`` in ``out_tensor``, or 0 when their union area is not positive.
    """
    a_left, a_top, a_right, a_bottom = _get_boundaries(out_tensor, box_a_idx)
    b_left, b_top, b_right, b_bottom = _get_boundaries(out_tensor, box_b_idx)

    # Intersection extents, clamped at zero for disjoint boxes.
    inter_w = tvm.te.max(0.0, tvm.te.min(a_right, b_right) - tvm.te.max(a_left, b_left))
    inter_h = tvm.te.max(0.0, tvm.te.min(a_bottom, b_bottom) - tvm.te.max(a_top, b_top))
    inter_area = inter_h * inter_w

    area_a = (a_right - a_left) * (a_bottom - a_top)
    area_b = (b_right - b_left) * (b_bottom - b_top)
    union = area_a + area_b - inter_area
    return tvm.tir.Select(union <= 0.0, 0.0, inter_area / union)
+
+
def binary_search(ib, y, num_boxes, scores, score_threshold, out):
    """Emit a binary search counting the leading entries of row ``y`` above the threshold.

    Row ``y`` of ``scores`` is assumed sorted in descending order. After the emitted
    loop, ``out[y]`` holds the first index in ``[0, num_boxes)`` whose score is not
    strictly greater than ``score_threshold`` (``num_boxes`` if there is none).
    """
    low = ib.allocate("int32", (1,), name="lo", scope="local")
    high = ib.allocate("int32", (1,), name="hi", scope="local")

    low[0] = 0
    high[0] = num_boxes.astype("int32")

    with ib.while_loop(low[0] < high[0]):
        middle = (high[0] + low[0]) >> 1
        with ib.if_scope(scores[y, middle] > score_threshold):
            # The answer lies strictly after ``middle``.
            low[0] = middle + 1
        with ib.else_scope():
            high[0] = middle

    out[y] = low[0]
+
+
+def _estimate_max_detections(batch_class, input_image_size=None):
+    """Estimate maximum detections based on input image size and number of 
classes.
+
+    This provides a more intelligent default for production environments.
+    """
+    if input_image_size is not None:
+        # Estimate based on image size: larger images typically have more 
objects
+        if len(input_image_size) >= 2:
+            height, width = input_image_size[-2], input_image_size[-1]
+            total_pixels = height * width
+
+            # Base estimation per class based on image size
+            if total_pixels < 300000:  # Small images (< 300k pixels)
+                base_detections_per_class = min(50, max(10, total_pixels // 
2000))
+            elif total_pixels < 1000000:  # Medium images (< 1M pixels)
+                base_detections_per_class = min(100, max(25, total_pixels // 
3000))
+            else:  # Large images (>= 1M pixels)
+                base_detections_per_class = min(200, max(50, total_pixels // 
4000))
+
+            # Scale down for many classes (more realistic for multi-class 
scenarios)
+            if batch_class > 20:
+                # For many classes, reduce per-class detections to avoid 
explosion
+                detections_per_class = min(base_detections_per_class, 50)
+            else:
+                detections_per_class = base_detections_per_class
+        else:
+            detections_per_class = 50  # fallback
+    else:
+        # Fallback to class-based estimation
+        if batch_class == 1:
+            detections_per_class = 100  # Single class detection
+        elif batch_class <= 10:
+            detections_per_class = 50  # Small multi-class
+        else:
+            detections_per_class = 25  # Large multi-class (COCO-like)
+
+    return batch_class * detections_per_class
+
+
def collect_selected_indices(
    num_class,
    selected_indices,
    num_detections,
    row_offsets,
    ir,
    max_output_boxes_per_class=None,
    output_shape=None,
    num_total_detections=None,
    input_image_size=None,
):
    """Collect selected indices from the core NMS loop into one linear output.

    Parameters
    ----------
    num_class : int
        Number of classes; forwarded unchanged to ``ir``.
    selected_indices : tvm.te.Tensor
        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices
        of selected boxes by the core NMS loop.
    num_detections : tvm.te.Tensor
        1-D tensor with shape (batch_size * num_classes,), representing the number of boxes
        selected by the core NMS loop, per batch and class.
    row_offsets : tvm.te.Tensor
        1-D tensor with shape (batch_size * num_classes,); this should be the exclusive scan
        of num_detections.
    ir : function
        A function to generate IR for CPU or GPU, see its usage in vision/nms.py and
        cuda/nms.py. It is called as ``ir(num_class, selected_indices, num_detections,
        row_offsets, out, max_output_boxes_per_class)``.
    max_output_boxes_per_class : int or tvm.te.Tensor, optional
        Per-class cap on selected boxes. Used to size the output and forwarded to ``ir``.
    output_shape : tuple, optional
        Explicit output shape. When given, it takes precedence over every other sizing rule.
    num_total_detections : tvm.te.Tensor, optional
        Currently only used as a flag that selects estimation-based sizing when the cap is
        not a Python int. The output is not yet trimmed to this count.
    input_image_size : optional
        Hint forwarded to ``_estimate_max_detections`` for estimation-based sizing.

    Returns
    -------
    out : tvm.te.Tensor
        int64 tensor of shape ``(out_rows, 3)``. ``out_rows`` is chosen, in order of
        precedence, as:

        * ``output_shape`` if given;
        * ``batch_class * max_output_boxes_per_class`` if the cap is a Python int;
        * an estimate if ``num_total_detections`` is given;
        * ``batch_class * cap`` if the cap is a te.Tensor whose value can be read
          statically, else ``batch_class * num_boxes``;
        * ``batch_class * num_boxes`` otherwise.

        Rows are ordered such that selected boxes from batch 0, class 0 come first, in
        descending order of scores, followed by boxes from batch 0, class 1, etc.
    """
    batch_class, num_boxes = selected_indices.shape

    def _static_max_boxes(cap_tensor):
        """Best-effort compile-time read of a scalar / 1-element cap; ``num_boxes`` if unknown."""
        try:
            if len(cap_tensor.shape) == 0:
                return int(cap_tensor.data.numpy())
            if len(cap_tensor.shape) == 1 and cap_tensor.shape[0] == 1:
                return int(cap_tensor.data.numpy()[0])
        except (ValueError, IndexError, AttributeError):
            # Deliberate best-effort: a symbolic cap cannot be read at compile time,
            # so fall through to the conservative upper bound.
            pass
        return num_boxes

    if output_shape is not None:
        out_shape = output_shape
    elif isinstance(max_output_boxes_per_class, int):
        out_shape = (batch_class * max_output_boxes_per_class, 3)
    elif num_total_detections is not None:
        # TODO: Implement dynamic trimming based on num_total_detections
        # Smart fallback based on input image size and typical production scenarios
        out_shape = (_estimate_max_detections(batch_class, input_image_size), 3)
    elif isinstance(max_output_boxes_per_class, te.Tensor):
        out_shape = (batch_class * _static_max_boxes(max_output_boxes_per_class), 3)
    else:
        out_shape = (batch_class * num_boxes, 3)

    return te.extern(
        [out_shape],
        [selected_indices, num_detections, row_offsets],
        lambda ins, outs: ir(
            num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class
        ),
        dtype=["int64"],
        name="collect_indices",
        tag="collect_indices",
    )
+
+
def collect_selected_indices_and_scores(
    selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir
):
    """Collect selected indices and scores from the core NMS loop into one linear output.

    Parameters
    ----------
    selected_indices : tvm.te.Tensor
        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices
        of selected boxes by the core NMS loop.
    selected_scores : tvm.te.Tensor
        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores
        of selected boxes by the core NMS loop.
    num_detections : tvm.te.Tensor
        2-D tensor with shape (batch_size, num_classes), representing the number of boxes
        selected by the core NMS loop, per batch and class.
    row_offsets : tvm.te.Tensor
        2-D tensor with shape (batch_size, num_classes); this should be the exclusive scan
        of num_detections along axis 1. Its shape also determines the batch size and the
        number of classes of the output.
    num_total_detections : tvm.te.Tensor
        Passed to ``ir`` as its fifth input. Presumably the total selected-box count per
        batch; confirm against the ``ir`` implementation.
    ir : function
        A function to generate IR for CPU or GPU, see its usage in vision/nms.py and
        cuda/nms.py. It is called as ``ir(selected_indices, selected_scores, num_detections,
        row_offsets, num_total_detections, out_indices, out_scores)``.

    Returns
    -------
    out : [tvm.te.Tensor, tvm.te.Tensor]
        Two tensors. The first is int64 indices of shape
        (batch_size, num_class * num_boxes, 2). The second is float32 scores of shape
        (batch_size, num_class * num_boxes).
    """
    batch_size, num_class = row_offsets.shape
    num_boxes = selected_indices.shape[1]
    return te.extern(
        [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)],
        [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections],
        lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]),
        dtype=["int64", "float32"],
        name="collect_indices_and_scores",
        tag="collect_indices_and_scores",
    )
+
+
def _all_class_nms_ir(
    boxes,
    sorted_scores,
    sorted_indices,
    valid_count,
    batch_class,
    num_class,
    num_anchors,
    iou_threshold,
    max_output_size_per_class,
    box_indices,
    selected_scores,
    num_valid_boxes,
    nms_loop,
    score_threshold=None,
):
    """Emit the IR for all-class NMS by plugging per-box callbacks into ``nms_loop``.

    Each of the ``batch_class`` rows is processed independently by ``nms_loop``. When a
    box survives suppression, its original index is written to ``box_indices``. If
    ``selected_scores`` is provided, its score is written there as well.

    Threshold normalisation:

    * a Python float ``iou_threshold`` becomes a float32 immediate;
    * a Python int ``max_output_size_per_class`` becomes a TIR constant;
    * a te.Tensor of shape ``()`` or ``(1,)`` is read as its single element;
    * any other tensor shape silently falls back to 0.5 (IoU) or 1000 (max boxes).

    ``boxes`` is addressed with flat offsets ``batch * num_anchors * 4 + idx * 4``. This
    presumably assumes a contiguous (batch, num_anchors, 4) layout; TODO confirm with callers.

    Returns whatever ``nms_loop`` returns; presumably the generated statement.
    """
    builder = tvm.tir.ir_builder.create()
    boxes_ptr = builder.buffer_ptr(boxes)
    scores_ptr = builder.buffer_ptr(sorted_scores)
    order_ptr = builder.buffer_ptr(sorted_indices)
    valid_count_ptr = builder.buffer_ptr(valid_count)
    kept_indices_ptr = builder.buffer_ptr(box_indices)
    kept_count_ptr = builder.buffer_ptr(num_valid_boxes)
    kept_scores_ptr = None if selected_scores is None else builder.buffer_ptr(selected_scores)

    def _single_element(tensor, fallback):
        # Scalar or one-element tensors are read directly; other shapes use the fallback.
        rank = len(tensor.shape)
        if rank == 0:
            return tensor()
        if rank == 1 and tensor.shape[0] == 1:
            return tensor[0]
        return fallback

    if isinstance(iou_threshold, float):
        iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
    elif isinstance(iou_threshold, te.Tensor):
        iou_threshold = _single_element(iou_threshold, tvm.tir.FloatImm("float32", 0.5))

    if isinstance(max_output_size_per_class, int):
        max_output_size_per_class = tvm.tir.const(max_output_size_per_class)
    elif isinstance(max_output_size_per_class, te.Tensor):
        max_output_size_per_class = _single_element(
            max_output_size_per_class, tvm.tir.const(1000)
        )

    def calc_overlap(i, j, k):
        # Rows are (batch, class) pairs; every class of a batch shares that batch's boxes.
        batch_base = (i // num_class) * num_anchors * 4
        return calculate_overlap(
            boxes_ptr,
            batch_base + order_ptr[i, j] * 4,
            batch_base + order_ptr[i, k] * 4,
        )

    def on_new_valid_box(ib, tid, num_current_valid_box, i, j):
        # `tid + 0` forces a PrimExpr comparison even when tid is a plain Python int.
        with ib.if_scope(tid + 0 == 0):
            kept_indices_ptr[i, num_current_valid_box] = order_ptr[i, j]
            if kept_scores_ptr is not None:
                kept_scores_ptr[i, num_current_valid_box] = scores_ptr[i, j]

    def on_new_invalidated_box(*_):
        pass

    def needs_bbox_check(*_):
        return tvm.tir.const(True)

    no_top_k = tvm.tir.IntImm("int32", -1)
    return nms_loop(
        builder,
        batch_class,
        no_top_k,
        iou_threshold,
        max_output_size_per_class,
        valid_count_ptr,
        on_new_valid_box,
        on_new_invalidated_box,
        needs_bbox_check,
        calc_overlap,
        scores_ptr,
        kept_count_ptr,
        score_threshold,
    )
+
+
def run_all_class_nms(
    boxes,
    sorted_scores,
    sorted_indices,
    valid_count,
    max_output_size_per_class,
    iou_threshold,
    nms_loop,
    return_scores=False,
    score_threshold=None,
):
    """The core all class NMS routine.

    Parameters
    ----------
    boxes : tvm.te.Tensor
        3-D tensor with shape (batch_size, num_boxes, 4).
    sorted_scores : tvm.te.Tensor
        2-D tensor with shape (batch_size * num_classes, num_boxes).
        One of the outputs from argsort.
    sorted_indices : tvm.te.Tensor
        2-D tensor with shape (batch_size * num_classes, num_boxes).
        The other output from argsort.
    valid_count : tvm.te.Tensor
        1-D tensor with shape (batch_size * num_classes,), representing the number of boxes
        whose score is above score_threshold, per batch and class.
    max_output_size_per_class : int or tvm.te.Tensor
        The maximum number of output selected boxes per class.
    iou_threshold : float or tvm.te.Tensor
        IoU test threshold.
    nms_loop : function
        A core NMS loop, see its usage in vision/nms.py and cuda/nms.py.
    return_scores : bool, optional
        Whether or not to return selected scores, needed by the tensorflow output format.
        Any falsy value disables the scores output.
    score_threshold : tvm.te.Tensor, optional
        When given, it is appended as an extra extern input and forwarded to ``nms_loop``.

    Returns
    -------
    out : a list of tvm.te.Tensor
        Three tensors:

        * int32 indices of shape (batch_size * num_class, num_boxes);
        * float32 scores of the same shape, or None if ``return_scores`` is falsy;
        * int32 ``num_selected_boxes`` of shape (batch_size * num_class,), the number of
          selected boxes per batch and class.
    """
    batch, num_boxes, _ = boxes.shape
    batch_class = sorted_scores.shape[0]
    num_class = batch_class // batch

    extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count]
    if score_threshold is not None:
        extern_inputs.append(score_threshold)

    def _build_ir(ins, box_indices_out, scores_out, counts_out):
        # Shared IR generator for both output variants; scores_out may be None.
        return _all_class_nms_ir(
            ins[0],  # boxes
            ins[1],  # sorted_scores
            ins[2],  # sorted_indices
            ins[3],  # valid_count
            batch_class,
            num_class,
            num_boxes,
            iou_threshold,
            max_output_size_per_class,
            box_indices_out,
            scores_out,
            counts_out,
            nms_loop,
            ins[4] if score_threshold is not None else None,  # score_threshold
        )

    # `not return_scores` (rather than `is False`) so None/0 also mean "no scores".
    if not return_scores:
        all_class_num0_buf = tvm.tir.decl_buffer(
            (batch_class, num_boxes), "int32", "all_class_nms0", data_alignment=8
        )
        all_class_num1_buf = tvm.tir.decl_buffer(
            (batch_class,), "int32", "all_class_nms1", data_alignment=8
        )
        selected_indices, num_detections = te.extern(
            [(batch_class, num_boxes), (batch_class,)],
            extern_inputs,
            lambda ins, outs: _build_ir(ins, outs[0], None, outs[1]),
            out_buffers=[all_class_num0_buf, all_class_num1_buf],
            dtype=["int32", "int32"],
            name="all_class_nms",
            tag="all_class_nms",
        )
        return selected_indices, None, num_detections

    return te.extern(
        [(batch_class, num_boxes), (batch_class, num_boxes), (batch_class,)],
        extern_inputs,
        lambda ins, outs: _build_ir(ins, outs[0], outs[1], outs[2]),
        dtype=["int32", "float32", "int32"],
        name="all_class_nms",
        tag="all_class_nms",
    )
diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h
index bb4098ae82..f09dcb7f82 100644
--- a/src/relax/ir/emit_te.h
+++ b/src/relax/ir/emit_te.h
@@ -51,6 +51,10 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode {
         .def_ro("shape", &RXPlaceholderOpNode::shape)
         .def_ro("dtype", &RXPlaceholderOpNode::dtype);
   }
+
+  // FFI system configuration for structural equality and hashing
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = 
kTVMFFISEqHashKindTreeNode;
+
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TEPlaceholderOp", 
RXPlaceholderOpNode,
                                     te::PlaceholderOpNode);
 };
diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc
new file mode 100644
index 0000000000..2a1ad8f40a
--- /dev/null
+++ b/src/relax/op/vision/nms.cc
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "nms.h"
+
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/expr.h>
+#include <tvm/ir/op.h>
+#include <tvm/relax/attrs/vision.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { 
AllClassNonMaximumSuppressionAttrs::RegisterReflection(); }
+
+/* relax.vision.all_class_non_max_suppression */
+
+Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr 
max_output_boxes_per_class,
+                                   Expr iou_threshold, Expr score_threshold,
+                                   ffi::String output_format) {
+  auto attrs = tvm::ffi::make_object<AllClassNonMaximumSuppressionAttrs>();
+  attrs->output_format = output_format;
+
+  static const Op& op = Op::Get("relax.vision.all_class_non_max_suppression");
+  return Call(op,
+              {std::move(boxes), std::move(scores), 
std::move(max_output_boxes_per_class),
+               std::move(iou_threshold), std::move(score_threshold)},
+              Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.all_class_non_max_suppression",
+                        all_class_non_max_suppression);
+}
+
+StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& 
ctx) {
+  tvm::ffi::Array<TensorStructInfo> input_sinfo = 
GetInputTensorStructInfo(call, ctx);
+  const auto boxes_sinfo = input_sinfo[0];
+  const auto scores_sinfo = input_sinfo[1];
+  ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim";
+  ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim";
+  ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D.";
+  ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores count should 
be 3-D.";
+
+  const auto batch = boxes_sinfo->shape.as<ShapeExprNode>()->values[0];
+  const auto num_classes = scores_sinfo->shape.as<ShapeExprNode>()->values[1];
+  const auto num_boxes = boxes_sinfo->shape.as<ShapeExprNode>()->values[1];
+
+  auto vdev = input_sinfo[0]->vdevice;
+  const auto* attrs = call->attrs.as<AllClassNonMaximumSuppressionAttrs>();
+  if (attrs->output_format == "onnx") {
+    auto vdev = input_sinfo[0]->vdevice;
+    auto num_total_boxes = batch * num_classes * num_boxes;
+    tvm::ffi::Array<PrimExpr> oshape_values = {num_total_boxes, 3};
+    ShapeExpr oshape(oshape_values);
+    tvm::ffi::Array<PrimExpr> counts_values = {1};
+    ShapeExpr counts_shape(counts_values);
+    tvm::ffi::Array<StructInfo> fields = {TensorStructInfo(oshape, 
DataType::Int(64), vdev),
+                                          TensorStructInfo(counts_shape, 
DataType::Int(64), vdev)};
+    return TupleStructInfo(fields);
+  }
+
+  auto num_total_boxes_per_batch = num_classes * num_boxes;
+  tvm::ffi::Array<PrimExpr> indices_values = {batch, 
num_total_boxes_per_batch, 2};
+  ShapeExpr indices_shape(indices_values);
+  tvm::ffi::Array<PrimExpr> scores_values = {batch, num_total_boxes_per_batch};
+  ShapeExpr scores_shape(scores_values);
+  tvm::ffi::Array<PrimExpr> counts_values = {batch};
+  ShapeExpr counts_shape(counts_values);
+  tvm::ffi::Array<StructInfo> fields = {TensorStructInfo(indices_shape, 
DataType::Int(64), vdev),
+                                        TensorStructInfo(scores_shape, 
DataType::Float(32), vdev),
+                                        TensorStructInfo(counts_shape, 
DataType::Int(64), vdev)};
+  return TupleStructInfo(fields);
+}
+
+TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression")
+    .set_attrs_type<AllClassNonMaximumSuppressionAttrs>()
+    .set_num_inputs(5)
+    .add_argument("boxes", "Tensor", "The input boxes in the format [batch, 
num_boxes, 4].")
+    .add_argument("scores", "Tensor",
+                  "Scores for each box and class in the format [batch, 
num_classes, num_boxes].")
+    .add_argument("max_output_boxes_per_class", "Tensor",
+                  "The maximum number of output boxes per class.")
+    .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the 
overlap test.")
+    .add_argument("score_threshold", "Tensor",
+                  "The score threshold to filter out low score boxes early.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllClassNMS)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h
new file mode 100644
index 0000000000..c86bf98c94
--- /dev/null
+++ b/src/relax/op/vision/nms.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file nms.h
+ * \brief The functions to make Relax Non-maximum suppression operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_NMS_H_
+#define TVM_RELAX_OP_VISION_NMS_H_
+
+#include <tvm/ffi/string.h>
+#include <tvm/relax/attrs/vision.h>
+#include <tvm/runtime/object.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Compute All Class NonMaximumSuppression. */
+Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr 
max_output_boxes_per_class,
+                                   Expr iou_threshold, Expr score_threshold,
+                                   ffi::String output_format);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_VISION_NMS_H_
diff --git a/src/te/operation/create_primfunc.cc 
b/src/te/operation/create_primfunc.cc
index 24c16ab268..fa84ab3863 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -650,7 +650,10 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& 
extern_op, CreateFuncInfo* inf
   // reads/writes filled in.
 
   BufferSubstituter substituter(var_map, input_buffer_map);
-  Stmt body = substituter(extern_op->body);
+  Stmt substituted_body = substituter(extern_op->body);
+
+  ProducerToBufferTransformer transformer(info->tensor2buffers);
+  Stmt body = transformer(substituted_body);
 
   // Step 4. Generate opaque block as body.
   return BlockRealize(/*iter_values=*/{},
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index d2f5a65593..e4960e5b1a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -3230,6 +3230,7 @@ def test_shape_dim_string_expression_graph_div_1():
                 gv: R.Tensor((A, B, A // B), dtype="float32") = x
                 R.output(gv)
             return gv
+
     # fmt: on
 
     tvm.ir.assert_structural_equal(tvm_model, Expected)
@@ -3269,5 +3270,430 @@ def test_shape_dim_string_expression_graph_div_2():
     tvm.ir.assert_structural_equal(tvm_model, Expected)
 
 
def test_nms():
    """Test NonMaxSuppression operator conversion using our AllClassNMS implementation."""
    nms_node = helper.make_node(
        "NonMaxSuppression",
        ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"],
        ["selected_indices"],
        center_point_box=0,
    )

    boxes_shape = [1, 5, 4]  # batch_size, num_boxes, 4
    scores_shape = [1, 2, 5]  # batch_size, num_classes, num_boxes

    graph = helper.make_graph(
        [nms_node],
        "nms_test",
        inputs=[
            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
        ],
        initializer=[
            helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]),
            helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]),
            helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]),
        ],
        outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [0, 3])],
    )

    model = helper.make_model(graph, producer_name="nms_test")
    model.opset_import[0].version = 11

    # Use deterministic random inputs for consistent testing
    bg = np.random.MT19937(0)
    rg = np.random.Generator(bg)
    boxes = rg.standard_normal(size=boxes_shape).astype(np.float32)
    scores = rg.standard_normal(size=scores_shape).astype(np.float32)
    inputs = {"boxes": boxes, "scores": scores}

    # Run ONNX Runtime
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Run TVM
    tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")

    if isinstance(tvm_output, (list, tuple)):
        tvm_selected = tvm_output[0].numpy()
    else:
        tvm_selected = tvm_output.numpy()
    ort_selected = ort_output[0]

    # TVM may return a statically sized (padded) selection, so it must contain at least
    # every row ONNX Runtime selected, and those leading rows must match exactly.
    # Comparing only min(rows) would pass silently if TVM dropped selections.
    num_expected = ort_selected.shape[0]
    assert tvm_selected.shape[0] >= num_expected, (
        f"TVM returned {tvm_selected.shape[0]} rows, ONNX Runtime selected {num_expected}"
    )
    np.testing.assert_array_equal(tvm_selected[:num_expected], ort_selected)
+
+
def test_nms_algorithm_correctness():
    """Test NMS algorithm correctness with fixed data to verify selection logic."""
    nms_node = helper.make_node(
        "NonMaxSuppression",
        ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"],
        ["selected_indices"],
        center_point_box=0,
    )

    # Fixed test data with known expected results.
    # With center_point_box=0, ONNX reads each box as corners [y1, x1, y2, x2]. These boxes
    # are symmetric, so the x/y order does not matter here.
    boxes_data = np.array(
        [
            [
                [0.0, 0.0, 1.0, 1.0],  # Box 0: unit square at the origin
                # Box 1 overlaps box 0, but IoU = 0.25 / 1.75 ~= 0.14 < 0.5, so it is kept
                [0.5, 0.5, 1.5, 1.5],
                [2.0, 2.0, 3.0, 3.0],  # Box 2: disjoint from boxes 0 and 1
            ]
        ],
        dtype=np.float32,
    )

    # Scores: higher score = better. Both classes rank the boxes 0 > 1 > 2.
    scores_data = np.array(
        [
            [
                [0.9, 0.8, 0.7],  # Class 0
                [0.6, 0.5, 0.4],  # Class 1
            ]
        ],
        dtype=np.float32,
    )

    boxes_shape = [1, 3, 4]  # batch_size, num_boxes, 4
    scores_shape = [1, 2, 3]  # batch_size, num_classes, num_boxes

    # No box is suppressed by IoU, so max_output_boxes_per_class=2 is what drops box 2.
    # Each class keeps boxes 0 and 1, giving 2 classes * 2 boxes = 4 selected rows.
    graph = helper.make_graph(
        [nms_node],
        "nms_test_correctness",
        inputs=[
            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
        ],
        initializer=[
            helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [2]),
            helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]),
            helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]),
        ],
        outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [4, 3])],
    )

    model = helper.make_model(graph, producer_name="nms_test_correctness")
    # Pin the opset like the other NMS tests so ONNX Runtime and the converter agree.
    model.opset_import[0].version = 11

    # Use fixed inputs instead of random
    inputs = {
        "boxes": boxes_data,
        "scores": scores_data,
    }

    check_correctness(model, inputs=inputs, opset=11)
+
+
def test_nms_iou_suppression():
    """Test that NMS correctly suppresses overlapping boxes based on IoU threshold."""
    nms_node = helper.make_node(
        "NonMaxSuppression",
        ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"],
        ["selected_indices"],
        center_point_box=0,
    )

    # Overlapping boxes where box 0 has the higher score and should be kept.
    boxes_data = np.array(
        [
            [
                [0.0, 0.0, 1.0, 1.0],  # Box 0: highest score
                # Box 1: IoU with box 0 = 0.81 / 1.19 ~= 0.68 > 0.5, so it is suppressed
                [0.1, 0.1, 1.1, 1.1],
                [2.0, 2.0, 3.0, 3.0],  # Box 2: no overlap, kept
            ]
        ],
        dtype=np.float32,
    )

    scores_data = np.array([[[0.9, 0.8, 0.7]]], dtype=np.float32)

    boxes_shape = [1, 3, 4]
    scores_shape = [1, 1, 3]

    graph = helper.make_graph(
        [nms_node],
        "nms_test_iou_suppression",
        inputs=[
            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
        ],
        initializer=[
            helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [2]),
            helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]),
            helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]),
        ],
        outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [2, 3])],
    )

    model = helper.make_model(graph, producer_name="nms_test_iou_suppression")
    model.opset_import[0].version = 11

    inputs = {
        "boxes": boxes_data,
        "scores": scores_data,
    }

    # Run ONNX Runtime
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Run TVM
    tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")

    if isinstance(tvm_output, (list, tuple)):
        tvm_selected = tvm_output[0].numpy()
    else:
        tvm_selected = tvm_output.numpy()
    ort_selected = ort_output[0]

    # Rows are [batch, class, box]: box 0 and box 2 survive, box 1 is suppressed.
    expected = np.array([[0, 0, 0], [0, 0, 2]], dtype=np.int64)
    np.testing.assert_array_equal(ort_selected, expected)

    # TVM may pad its output, but it must contain every expected row, in order.
    num_expected = expected.shape[0]
    assert tvm_selected.shape[0] >= num_expected, (
        f"TVM returned {tvm_selected.shape[0]} rows, expected at least {num_expected}"
    )
    np.testing.assert_array_equal(tvm_selected[:num_expected], expected)
+
+
def test_nms_max_boxes_limit():
    """Test that NMS correctly limits the number of boxes per class."""
    nms_node = helper.make_node(
        "NonMaxSuppression",
        ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"],
        ["selected_indices"],
        center_point_box=0,
    )

    # Four pairwise-disjoint boxes (IoU 0), but only 2 may be kept per class.
    boxes_data = np.array(
        [
            [
                [0.0, 0.0, 1.0, 1.0],  # Box 0
                [2.0, 0.0, 3.0, 1.0],  # Box 1
                [0.0, 2.0, 1.0, 3.0],  # Box 2
                [2.0, 2.0, 3.0, 3.0],  # Box 3
            ]
        ],
        dtype=np.float32,
    )

    # All boxes have different scores
    scores_data = np.array([[[0.9, 0.8, 0.7, 0.6]]], dtype=np.float32)

    boxes_shape = [1, 4, 4]
    scores_shape = [1, 1, 4]

    graph = helper.make_graph(
        [nms_node],
        "nms_test_max_boxes_limit",
        inputs=[
            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
        ],
        initializer=[
            # Limit to 2 boxes per class
            helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [2]),
            # Low IoU threshold; irrelevant here since no boxes overlap
            helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]),
            helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.1]),
        ],
        outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [2, 3])],
    )

    model = helper.make_model(graph, producer_name="nms_test_max_boxes_limit")
    model.opset_import[0].version = 11

    inputs = {
        "boxes": boxes_data,
        "scores": scores_data,
    }

    # Run ONNX Runtime
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Run TVM
    tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")

    if isinstance(tvm_output, (list, tuple)):
        tvm_selected = tvm_output[0].numpy()
    else:
        tvm_selected = tvm_output.numpy()
    ort_selected = ort_output[0]

    # Rows are [batch, class, box]: the two highest-scoring boxes (0 and 1) are kept.
    expected = np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int64)
    np.testing.assert_array_equal(ort_selected, expected)

    # TVM may pad its output, but it must contain every expected row, in order.
    num_expected = expected.shape[0]
    assert tvm_selected.shape[0] >= num_expected, (
        f"TVM returned {tvm_selected.shape[0]} rows, expected at least {num_expected}"
    )
    np.testing.assert_array_equal(tvm_selected[:num_expected], expected)
+
+
def test_nms_score_threshold():
    """Test NMS with a score threshold just below the lowest box score.

    The score threshold (0.05) is below every score (the lowest is 0.1), so no
    box is filtered out: all three non-overlapping boxes are kept. Both TVM and
    ONNX Runtime are expected to report a [3, 3] result, which is compared row
    by row.
    """
    nms_node = helper.make_node(
        "NonMaxSuppression",
        ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"],
        ["selected_indices"],
        center_point_box=0,
    )

    # Three non-overlapping boxes, so NMS suppresses none of them.
    boxes_data = np.array(
        [
            [
                [0.0, 0.0, 1.0, 1.0],  # Box 0
                [2.0, 0.0, 3.0, 1.0],  # Box 1
                [0.0, 2.0, 1.0, 3.0],  # Box 2
            ]
        ],
        dtype=np.float32,
    )

    # Scores 0.9, 0.3 and 0.1 all exceed the 0.05 score threshold, so all 3 boxes are kept.
    scores_data = np.array([[[0.9, 0.3, 0.1]]], dtype=np.float32)

    boxes_shape = [1, 3, 4]
    scores_shape = [1, 1, 3]

    graph = helper.make_graph(
        [nms_node],
        "nms_test_score_threshold",
        inputs=[
            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
        ],
        initializer=[
            helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]),
            helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]),
            helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.05]),
        ],
        outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [3, 3])],
    )

    model = helper.make_model(graph, producer_name="nms_test_score_threshold")
    model.opset_import[0].version = 11

    inputs = {
        "boxes": boxes_data,
        "scores": scores_data,
    }

    # Run ONNX Runtime.
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Run TVM.
    tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")

    # Custom NMS output comparison.
    if isinstance(tvm_output, (list, tuple)):
        tvm_selected = tvm_output[0].numpy()
    else:
        tvm_selected = tvm_output.numpy()
    ort_selected = ort_output[0]

    # Guard against a vacuous pass: every box survives in the reference, and
    # TVM must produce at least as many rows.
    assert ort_selected.shape == (3, 3), ort_selected.shape
    assert tvm_selected.shape[0] >= ort_selected.shape[0], tvm_selected.shape

    # Compare only the rows ONNX Runtime reports as valid selections.
    num_valid = ort_selected.shape[0]
    tvm.testing.assert_allclose(tvm_selected[:num_valid], ort_selected, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py
new file mode 100644
index 0000000000..97145a53ff
--- /dev/null
+++ b/tests/python/relax/test_op_vision.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op, VDevice
+from tvm.script import relax as R
+
+
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
    """Normalize ``call`` with ``bb`` and assert its struct info equals ``expected_sinfo``."""
    actual_sinfo = bb.normalize(call).struct_info
    tvm.ir.assert_structural_equal(actual_sinfo, expected_sinfo)
+
def test_all_class_non_max_suppression_infer_struct_info():
    """Static-shape struct info inference for ``all_class_non_max_suppression``.

    Expects a tuple of an int64 (batch * classes * boxes, 3) tensor and an
    int64 (1,) tensor.
    """
    builder = relax.BlockBuilder()
    n_batch, n_class, n_box = 10, 8, 5
    boxes = relax.Var("boxes", R.Tensor((n_batch, n_box, 4), "float32"))
    scores = relax.Var("scores", R.Tensor((n_batch, n_class, n_box), "float32"))

    nms_call = relax.op.vision.all_class_non_max_suppression(
        boxes,
        scores,
        relax.const(10, "int64"),
        relax.const(0.5, "float32"),
        relax.const(0.1, "float32"),
        "onnx",
    )
    expected = relax.TupleStructInfo(
        [
            relax.TensorStructInfo((n_batch * n_class * n_box, 3), "int64"),
            relax.TensorStructInfo((1,), "int64"),
        ]
    )
    _check_inference(builder, nms_call, expected)
+
def test_all_class_non_max_suppression_wrong_input_number():
    """Calling the op with only ``boxes`` and ``scores`` must raise ``TVMError``."""
    boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32"))
    scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32"))

    with pytest.raises(TVMError):
        relax.op.vision.all_class_non_max_suppression(boxes, scores)
+
def test_all_class_non_max_suppression_infer_struct_info_shape_var():
    """Symbolic-shape struct info inference for ``all_class_non_max_suppression``.

    With symbolic batch/class/box sizes, the first output's row count is the
    product ``batch_size * num_classes * num_boxes``.
    """
    builder = relax.BlockBuilder()
    n_batch, n_class, n_box = (
        tir.Var(name, "int64") for name in ("batch_size", "num_classes", "num_boxes")
    )
    boxes = relax.Var("boxes", R.Tensor((n_batch, n_box, 4), "float32"))
    scores = relax.Var("scores", R.Tensor((n_batch, n_class, n_box), "float32"))

    nms_call = relax.op.vision.all_class_non_max_suppression(
        boxes,
        scores,
        relax.const(10, "int64"),
        relax.const(0.5, "float32"),
        relax.const(0.1, "float32"),
        "onnx",
    )
    expected = relax.TupleStructInfo(
        [
            relax.TensorStructInfo((n_batch * n_class * n_box, 3), "int64"),
            relax.TensorStructInfo((1,), "int64"),
        ]
    )
    _check_inference(builder, nms_call, expected)


if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py
new file mode 100644
index 0000000000..66e0adac3d
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
def _check(
    parsed: Union[relax.Function, IRModule],
    expect: Optional[Union[relax.Function, IRModule]],
):
    """Round-trip ``parsed`` through TVMScript and optionally compare it to ``expect``.

    ``parsed`` is printed with metadata, re-parsed, and required to be
    structurally equal to the re-parsed result. When ``expect`` is given,
    ``parsed`` must also be structurally equal to it.
    """
    test = parsed.script(show_meta=True)
    roundtrip_mod = tvm.script.from_source(test)
    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
    # ``expect`` is Optional: test for None explicitly rather than relying on
    # the truthiness of an IR object.
    if expect is not None:
        tvm.ir.assert_structural_equal(parsed, expect)
+
def test_all_class_non_max_suppression():
    """Parse ``R.vision.all_class_non_max_suppression`` from TVMScript.

    The parsed function must round-trip through TVMScript and match the same
    function built with ``relax.BlockBuilder``.
    """

    @R.function
    def foo(
        boxes: R.Tensor((10, 5, 4), "float32"),
        scores: R.Tensor((10, 8, 5), "float32"),
        max_output_boxes_per_class: R.Tensor((), "int64"),
        iou_threshold: R.Tensor((), "float32"),
        score_threshold: R.Tensor((), "float32"),
    ) -> R.Tuple(R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64")):
        gv: R.Tuple(
            R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64")
        ) = R.vision.all_class_non_max_suppression(
            boxes,
            scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            "onnx",
        )
        return gv

    # Same signature as ``foo``, in the same parameter order.
    params = [
        relax.Var("boxes", R.Tensor((10, 5, 4), "float32")),
        relax.Var("scores", R.Tensor((10, 8, 5), "float32")),
        relax.Var("max_output_boxes_per_class", R.Tensor((), "int64")),
        relax.Var("iou_threshold", R.Tensor((), "float32")),
        relax.Var("score_threshold", R.Tensor((), "float32")),
    ]

    builder = relax.BlockBuilder()
    with builder.function("foo", params):
        nms_out = builder.emit(relax.op.vision.all_class_non_max_suppression(*params, "onnx"))
        builder.emit_func_output(nms_out)

    _check(foo, builder.get()["foo"])


if __name__ == "__main__":
    tvm.testing.main()


Reply via email to