This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 78b5ed068c [Relax] Implement dynamic output trimming for NMS (#18676)
78b5ed068c is described below

commit 78b5ed068cda945913e9ded5788f4818fdf67c15
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Jan 22 19:43:14 2026 +0800

    [Relax] Implement dynamic output trimming for NMS (#18676)
    
    ## Why
    The NMS operator returns a fixed-size output with trailing garbage data,
    wasting memory and requiring manual trimming for ONNX compatibility.

    ## How
    - Add dynamic_strided_slice to trim NMS output to valid detections only
    - Build slice parameters using TE compute to avoid legalization issues
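
To make the behavioral change concrete, here is a rough NumPy-level sketch (not part of the patch; the array values are made up) of what the trimming removes for a hypothetical run with batch=1, two classes, and max_output_boxes_per_class=3:

```python
import numpy as np

# Hypothetical padded output under the old behavior:
# batch * num_classes * max_boxes = 6 rows, but only 4 of them are valid.
selected_indices_padded = np.array(
    [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [0, 0, 0], [0, 0, 0]],
    dtype=np.int64,
)
num_total_detections = np.array([4], dtype=np.int64)

# Old workflow: callers had to slice off the trailing garbage themselves.
valid = selected_indices_padded[: int(num_total_detections[0]), :]
assert valid.shape == (4, 3)

# With this patch, the legalized op already returns the trimmed
# (num_total_detections, 3) tensor, so the manual slice is unnecessary.
```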
---
 python/tvm/relax/op/vision/nms.py                 |   8 +-
 python/tvm/relax/transform/legalize_ops/vision.py | 114 ++++++++++------------
 tests/python/relax/test_op_vision.py              |  98 ++++++++++++++++++-
 3 files changed, 146 insertions(+), 74 deletions(-)

diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py
index 3714b00b01..4c50748bdb 100644
--- a/python/tvm/relax/op/vision/nms.py
+++ b/python/tvm/relax/op/vision/nms.py
@@ -54,12 +54,10 @@ def all_class_non_max_suppression(
         `num_total_detection` of shape `(1,)` representing the total number of selected
         boxes. The three values in `indices` encode batch, class, and box indices.
         Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
-        first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of
-        `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection`
-        rows are valid.
+        first, in descending order of scores, followed by boxes from batch 0, class 1, etc.
+        The output is trimmed with dynamic_strided_slice to keep only the valid detections,
+        so the first tensor has shape (num_total_detection, 3) and contains only valid rows.
 
-        TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly.
-        This would eliminate the need for manual trimming and improve memory efficiency.
         If `output_format` is "tensorflow", the output is three tensors, the first
         is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of
         size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py
index f910f62cec..9511c13018 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -15,64 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 """Default legalization function for vision network related operators."""
-from tvm import topi, te
-from tvm import relax
+from tvm import relax, te, tir, topi
+
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, TupleGetItem
 from .common import register_legalize
 
 
-def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
-    """Create a proper NMS implementation that follows the correct algorithm"""
-    scores_shape = list(scores.shape)
-    if len(scores_shape) == 3:
-        batch, num_classes, _ = scores_shape
-    elif len(scores_shape) == 2:
-        num_classes, _ = scores_shape
-        batch = 1
-    else:
-        raise ValueError(f"Unexpected scores shape: {scores_shape}")
-
-    if hasattr(max_output_boxes_per_class, "data"):
-        max_boxes = int(max_output_boxes_per_class.data.numpy())
-    else:
-        max_boxes = 3  # Default value
-
-    expected_detections = batch * num_classes * max_boxes
-
-    selected_indices_full, _ = topi.vision.all_class_non_max_suppression(
-        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
-    )
-
-    def slice_to_onnx_shape(data, expected_size):
-        def compute_element(i, j):
-            return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0))
-
-        return te.compute((expected_size, 3), compute_element, name="sliced_indices")
-
-    sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections)
-
-    actual_detections = te.compute(
-        (1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections"
-    )
-
-    return [sliced_indices, actual_detections]
-
-
 @register_legalize("relax.vision.all_class_non_max_suppression")
 def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
-    """Legalize all_class_non_max_suppression with fixed shape output.
-
-    Note: This implementation outputs fixed-size tensors with trailing garbage data.
-    Only the first `num_total_detection` rows contain valid data. Users should use
-    the `valid_count` tensor to determine how many rows are actually valid.
-
-    For complete ONNX compatibility, users can post-process the output:
-    ```python
-    selected_indices, valid_count = nms_output
-    actual_count = int(valid_count.numpy()[0])
-    valid_indices = selected_indices.numpy()[:actual_count, :]
-    ```
+    """Legalize all_class_non_max_suppression with dynamic output trimming.
+
+    This implementation uses dynamic_strided_slice to trim the NMS output to only
+    contain valid detections, improving memory efficiency and ONNX compatibility.
+
+    Returns
+    -------
+    result : Tuple[Tensor, Tensor]
+        A tuple of (trimmed_indices, num_total_detections) where:
+        - trimmed_indices: Tensor of shape (num_total_detections, 3) containing only
+          valid detection indices (batch_id, class_id, box_id)
+        - num_total_detections: Tensor of shape (1,) with the count of valid detections
     """
     boxes = call.args[0]
     scores = call.args[1]
@@ -105,16 +68,37 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E
         output_format,
     )
 
-    # TODO: Implement dynamic output trimming for better memory efficiency
-    # Current approach returns fixed-size output with trailing garbage data
-    # Future improvements could include:
-    # 1. Dynamic strided_slice based on num_total_detections
-    # 2. Custom Relax operator with true dynamic shapes
-    # 3. VM builtin functions for runtime shape adjustment
-    # 4. Symbolic shape inference in Relax IR
-    #
-    # For now, users should trim manually:
-    # actual_count = int(num_total_detections.numpy()[0])
-    # valid_indices = selected_indices.numpy()[:actual_count, :]
-
-    return nms_result
+    # Dynamic output trimming using dynamic_strided_slice
+    # Extract selected_indices and num_total_detections from the NMS result
+    selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
+    num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))
+
+    # Build slicing parameters using TE to avoid high-level Relax ops during legalization
+    def build_begin():
+        return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin")
+
+    def build_strides():
+        return te.compute((2,), lambda i: tir.const(1, "int64"), name="strides")
+
+    def build_end(count_tensor):
+        # end = [count_tensor[0], 3]
+        def compute_end(i):
+            return tir.if_then_else(
+                i == 0,
+                tir.Cast("int64", count_tensor[0]),
+                tir.const(3, "int64"),
+            )
+
+        return te.compute((2,), compute_end, name="end")
+
+    begin = block_builder.call_te(build_begin)
+    strides = block_builder.call_te(build_strides)
+    end = block_builder.call_te(build_end, num_total_detections)
+
+    # Apply dynamic strided slice to trim to valid detections only
+    trimmed_indices = block_builder.emit(
+        relax.op.dynamic_strided_slice(selected_indices, begin, end, strides)
+    )
+
+    # Return trimmed indices along with num_total_detections for compatibility
+    return relax.Tuple([trimmed_indices, num_total_detections])
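
For reference, the slice emitted above is equivalent to the following NumPy indexing. This is a minimal sketch with placeholder data, where `begin`, `end`, and `strides` mirror the TE-built tensors ([0, 0], [num_total_detections[0], 3], and [1, 1]):

```python
import numpy as np

# Placeholder padded NMS output and valid-detection count.
selected_indices = np.zeros((6, 3), dtype=np.int64)
num_total_detections = np.array([4], dtype=np.int64)

# Slice parameters as produced by build_begin/build_end/build_strides above.
begin = np.array([0, 0], dtype=np.int64)
end = np.array([int(num_total_detections[0]), 3], dtype=np.int64)
strides = np.array([1, 1], dtype=np.int64)

# dynamic_strided_slice(selected_indices, begin, end, strides) behaves like:
trimmed = selected_indices[
    begin[0] : end[0] : strides[0],
    begin[1] : end[1] : strides[1],
]
assert trimmed.shape == (4, 3)  # (num_total_detections, 3)
```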
diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py
index 97145a53ff..660b5d2772 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import numpy as np
 import pytest
+
 import tvm
 import tvm.testing
-from tvm import relax, tir
-from tvm import TVMError
-from tvm.ir import Op, VDevice
+from tvm import TVMError, relax, tir
+from tvm.relax.transform import LegalizeOps
 from tvm.script import relax as R
 
 
@@ -53,7 +54,6 @@ def test_all_class_non_max_suppression_infer_struct_info():
 
 
 def test_all_class_non_max_suppression_wrong_input_number():
-    bb = relax.BlockBuilder()
     boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32"))
     scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32"))
 
@@ -86,5 +86,95 @@ def test_all_class_non_max_suppression_infer_struct_info_shape_var():
     )
 
 
+def test_all_class_non_max_suppression_legalize_dynamic_trim():
+    @tvm.script.ir_module
+    class NMSModule:
+        @R.function
+        def main(
+            boxes: R.Tensor((1, 5, 4), "float32"),
+            scores: R.Tensor((1, 2, 5), "float32"),
+        ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
+            max_output_boxes_per_class = R.const(3, "int64")
+            iou_threshold = R.const(0.5, "float32")
+            score_threshold = R.const(0.1, "float32")
+            return R.vision.all_class_non_max_suppression(
+                boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
+            )
+
+    mod = LegalizeOps()(NMSModule)
+
+    # Check legalized function has dynamic output (uses dynamic_strided_slice)
+    assert "dynamic_strided_slice" in str(mod)
+
+    ret_sinfo = mod["main"].ret_struct_info
+    tvm.ir.assert_structural_equal(
+        ret_sinfo,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(ndim=2, dtype="int64"),
+                relax.TensorStructInfo((1,), "int64"),
+            ]
+        ),
+    )
+
+
+def test_all_class_non_max_suppression_legalize_e2e():
+    @tvm.script.ir_module
+    class NMSModule:
+        @R.function
+        def main(
+            boxes: R.Tensor((1, 5, 4), "float32"),
+            scores: R.Tensor((1, 2, 5), "float32"),
+        ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
+            max_output_boxes_per_class = R.const(3, "int64")
+            iou_threshold = R.const(0.5, "float32")
+            score_threshold = R.const(0.1, "float32")
+            return R.vision.all_class_non_max_suppression(
+                boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
+            )
+
+    boxes_data = np.array(
+        [
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.1, 0.1, 1.1, 1.1],
+                [2.0, 2.0, 3.0, 3.0],
+                [4.0, 4.0, 5.0, 5.0],
+                [6.0, 6.0, 7.0, 7.0],
+            ]
+        ],
+        dtype=np.float32,
+    )
+    scores_data = np.array(
+        [[[0.9, 0.8, 0.7, 0.6, 0.5], [0.85, 0.75, 0.65, 0.55, 0.45]]],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(NMSModule)
+
+    # Check struct info
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(ndim=2, dtype="int64"),
+                relax.TensorStructInfo((1,), "int64"),
+            ]
+        ),
+    )
+
+    # Check runtime execution
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    result = vm["main"](
+        tvm.runtime.tensor(boxes_data, tvm.cpu()),
+        tvm.runtime.tensor(scores_data, tvm.cpu()),
+    )
+
+    selected_indices = result[0].numpy()
+    num_total_detections = int(result[1].numpy()[0])
+    tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections, 3))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
