This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bedfcb2b85 [Relax][ONNX] Set `max_output_boxes_per_class` default
value to 0 for NonMaxSuppression (#19547)
bedfcb2b85 is described below
commit bedfcb2b85d63fa75b1077d1b195b30f7d388e57
Author: Neo Chien <[email protected]>
AuthorDate: Wed May 13 12:30:41 2026 +0800
[Relax][ONNX] Set `max_output_boxes_per_class` default value to 0 for
NonMaxSuppression (#19547)
Hi Committers,
This PR is trying to fix issue #19544. Any suggestions would be
appreciated if you are available.
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 8 ++--
tests/python/relax/test_frontend_onnx.py | 57 +++++++++++++++++++++++++
2 files changed, 61 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 560b644de8..3f25d2ff3b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4781,9 +4781,9 @@ class NonMaxSuppression(OnnxOpConverter):
_, param_value = params[1][var_name]
max_output_boxes_per_class = int(param_value.numpy().item())
else:
- max_output_boxes_per_class = 100 # Default value
+ max_output_boxes_per_class = 0 # Default value
else:
- max_output_boxes_per_class = 100 # Default value
+ max_output_boxes_per_class = 0 # Default value
if iou_threshold is not None and isinstance(iou_threshold, relax.Constant):
iou_threshold = float(iou_threshold.data.numpy())
@@ -4870,9 +4870,9 @@ class AllClassNMS(OnnxOpConverter):
_, param_value = params[1][var_name]
max_output_boxes_per_class = int(param_value.numpy().item())
else:
- max_output_boxes_per_class = 100 # Default value
+ max_output_boxes_per_class = 0 # Default value
else:
- max_output_boxes_per_class = 100 # Default value
+ max_output_boxes_per_class = 0 # Default value
if iou_threshold is not None and isinstance(iou_threshold, relax.Constant):
iou_threshold = float(iou_threshold.data.numpy())
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7f1cecd1c9..0d1d9f2d7c 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4871,6 +4871,63 @@ def test_nms():
)
+@pytest.mark.parametrize("with_explicit_max", [False, True])
+def test_nms_max_output_boxes_per_class_zero(with_explicit_max: bool):
+ """ONNX default for max_output_boxes_per_class is 0, yielding empty output."""
+ node_inputs = ["boxes", "scores"]
+ initializer = []
+ if with_explicit_max:
+ node_inputs.append("max_output_boxes_per_class")
+ initializer.append(
+ helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [0])
+ )
+
+ nms_node = helper.make_node(
+ "NonMaxSuppression",
+ node_inputs,
+ ["selected_indices"],
+ center_point_box=0,
+ )
+
+ boxes_shape = [1, 4, 4]
+ scores_shape = [1, 1, 4]
+ graph = helper.make_graph(
+ [nms_node],
+ "nms_max_output_boxes_per_class_zero",
+ inputs=[
+ helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes_shape),
+ helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores_shape),
+ ],
+ initializer=initializer,
+ outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [0, 3])],
+ )
+
+ model = helper.make_model(graph, producer_name="nms_max_output_boxes_per_class_zero")
+ model.ir_version = 8
+ model.opset_import[0].version = 11
+
+ inputs = {
+ "boxes": np.array(
+ [
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.1, 1.0, 1.1],
+ [2.0, 2.0, 3.0, 3.0],
+ [2.0, 2.1, 3.0, 3.1],
+ ]
+ ],
+ dtype=np.float32,
+ ),
+ "scores": np.array([[[0.9, 0.8, 0.7, 0.6]]], dtype=np.float32),
+ }
+
+ check_correctness(model, inputs=inputs, opset=11)
+
+ tvm_out = run_in_tvm(model, inputs=inputs, opset=11)
+ tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, (list, tuple)) else tvm_out.numpy()
+ assert tvm_selected.shape == (0, 3)
+
+
def test_nms_algorithm_correctness():
 """Test NMS algorithm correctness with fixed data to verify suppression logic."""
nms_node = helper.make_node(