This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 38eb79c63f [Relax][Vision] Add get_valid_counts and classic NMS
(#18943)
38eb79c63f is described below
commit 38eb79c63f1514402e75efdf71ab1868532115c8
Author: HoYi <[email protected]>
AuthorDate: Sun Mar 29 01:24:27 2026 +0800
[Relax][Vision] Add get_valid_counts and classic NMS (#18943)
## Summary
Add `relax.vision.get_valid_counts` and classic
`relax.vision.non_max_suppression` for object-detection post-processing
pipelines.
`get_valid_counts` performs score-based bounding box filtering and
compacts valid boxes to the front of each batch. Classic
`non_max_suppression` performs flexible IoU-based suppression on
filtered boxes, complementing existing `all_class_non_max_suppression`
for custom post-processing workflows.
This PR implements the Relax-level registration, legalization, TOPI
compute, and test coverage for both operators.
## Changes
**Relax op registration and legalization:**
- C++ op functions, FFI registration, and struct info inference for both
operators (`vision.h`, `vision.cc`)
- Python wrappers with Relax docstrings (`vision.py`)
- Legalization to `topi.vision.get_valid_counts` and
`topi.vision.non_max_suppression`
- Additional struct-info validation for `score_index`, `id_index`, and
`coord_start` when `elem_length` is statically known
**TOPI and testing:**
- Full TOPI implementation for `get_valid_counts`
- Reimplementation of classic `non_max_suppression` in TOPI
- NumPy reference implementations in `tvm.topi.testing` for both
operators
- Op-level tests for struct info inference, legalization, invalid
attribute ranges, and e2e numerical correctness
- Stronger legalization tests that verify both `relax.call_tir`
introduction and removal of the original Relax vision op
## Limitations
- Attribute range validation for `score_index`, `id_index`, and
`coord_start` is only enforced when the input `elem_length` is
statically known during struct-info inference.
- Classic `non_max_suppression` follows the existing Relax / TOPI API
shape and is intended for single-class or class-aware custom
post-processing flows, distinct from `all_class_non_max_suppression`.
## Validation
```bash
pytest tests/python/relax/test_op_vision.py -k "get_valid_counts" -v
pytest tests/python/relax/test_op_vision.py -k "test_nms_" -v
```
All related tests passed.
---
include/tvm/relax/attrs/vision.h | 60 ++
python/tvm/relax/op/__init__.py | 8 +-
python/tvm/relax/op/op_attrs.py | 10 +
python/tvm/relax/op/vision/nms.py | 114 ++-
python/tvm/relax/transform/legalize_ops/vision.py | 30 +
python/tvm/topi/testing/__init__.py | 2 +
python/tvm/topi/testing/get_valid_counts_python.py | 68 ++
python/tvm/topi/testing/nms_python.py | 146 ++++
python/tvm/topi/vision/nms.py | 503 +++++++++++-
python/tvm/topi/vision/nms_util.py | 4 +-
src/relax/op/vision/nms.cc | 244 +++++-
src/relax/op/vision/nms.h | 9 +
tests/python/relax/test_op_vision.py | 862 ++++++++++++++++++++-
.../relax/test_tvmscript_parser_op_vision.py | 132 ++++
14 files changed, 2166 insertions(+), 26 deletions(-)
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 4e3351bb90..69ce458e7e 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,66 @@ struct ROIAlignAttrs : public
AttrsNodeReflAdapter<ROIAlignAttrs> {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs",
ROIAlignAttrs, BaseAttrsNode);
}; // struct ROIAlignAttrs
+/*! \brief Attributes used in GetValidCounts operator */
+struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
+ double score_threshold;
+ int id_index;
+ int score_index;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<GetValidCountsAttrs>()
+ .def_ro("score_threshold", &GetValidCountsAttrs::score_threshold,
+ "Lower limit of score for valid bounding boxes.")
+ .def_ro("id_index", &GetValidCountsAttrs::id_index,
+ "Index of the class categories, -1 to disable.")
+ .def_ro("score_index", &GetValidCountsAttrs::score_index,
+ "Index of the scores/confidence of boxes.");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.GetValidCountsAttrs",
GetValidCountsAttrs,
+ BaseAttrsNode);
+}; // struct GetValidCountsAttrs
+
+/*! \brief Attributes used in NonMaximumSuppression operator */
+struct NonMaximumSuppressionAttrs
+ : public AttrsNodeReflAdapter<NonMaximumSuppressionAttrs> {
+ int max_output_size;
+ double iou_threshold;
+ bool force_suppress;
+ int top_k;
+ int coord_start;
+ int score_index;
+ int id_index;
+ bool return_indices;
+ bool invalid_to_bottom;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<NonMaximumSuppressionAttrs>()
+ .def_ro("max_output_size",
&NonMaximumSuppressionAttrs::max_output_size,
+ "Max number of output valid boxes, -1 for no limit.")
+ .def_ro("iou_threshold", &NonMaximumSuppressionAttrs::iou_threshold,
+ "Non-maximum suppression IoU threshold.")
+ .def_ro("force_suppress", &NonMaximumSuppressionAttrs::force_suppress,
+ "Whether to suppress all detections regardless of class_id.")
+ .def_ro("top_k", &NonMaximumSuppressionAttrs::top_k,
+ "Keep maximum top k detections before nms, -1 for no limit.")
+ .def_ro("coord_start", &NonMaximumSuppressionAttrs::coord_start,
+ "Start index of the consecutive 4 coordinates.")
+ .def_ro("score_index", &NonMaximumSuppressionAttrs::score_index,
+ "Index of the scores/confidence of boxes.")
+ .def_ro("id_index", &NonMaximumSuppressionAttrs::id_index,
+ "Index of the class categories, -1 to disable.")
+ .def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices,
+ "Whether to return box indices in input data.")
+ .def_ro("invalid_to_bottom",
&NonMaximumSuppressionAttrs::invalid_to_bottom,
+ "Whether to move all valid bounding boxes to the top.");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
+ NonMaximumSuppressionAttrs, BaseAttrsNode);
+}; // struct NonMaximumSuppressionAttrs
+
+
/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box
decode). */
struct MultiboxTransformLocAttrs : public
AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
bool clip;
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index ee1a2c2420..0b8dc4e7de 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,13 @@ from .unary import (
tanh,
trunc,
)
-from .vision import all_class_non_max_suppression, multibox_transform_loc,
roi_align
+from .vision import (
+ all_class_non_max_suppression,
+ get_valid_counts,
+ multibox_transform_loc,
+ non_max_suppression,
+ roi_align,
+)
def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index e8c91f04b4..d85c439d3a 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -246,6 +246,16 @@ class AllClassNonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.all_class_non_max_suppression"""
+@tvm_ffi.register_object("relax.attrs.GetValidCountsAttrs")
+class GetValidCountsAttrs(Attrs):
+ """Attributes for vision.get_valid_counts"""
+
+
+@tvm_ffi.register_object("relax.attrs.NonMaximumSuppressionAttrs")
+class NonMaximumSuppressionAttrs(Attrs):
+ """Attributes for vision.non_max_suppression"""
+
+
@tvm_ffi.register_object("relax.attrs.ROIAlignAttrs")
class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
diff --git a/python/tvm/relax/op/vision/nms.py
b/python/tvm/relax/op/vision/nms.py
index 616c74ddf6..4eb3eb7f7a 100644
--- a/python/tvm/relax/op/vision/nms.py
+++ b/python/tvm/relax/op/vision/nms.py
@@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Non-maximum suppression operator"""
+"""Non-maximum suppression operators."""
-# from tvm import relax # Unused import
from . import _ffi_api
@@ -72,3 +71,114 @@ def all_class_non_max_suppression(
return _ffi_api.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, output_format
)
+
+
+def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
+ """Get valid count of bounding boxes given a score threshold.
+ Also moves valid boxes to the top of input data.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ 3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+ score_threshold : float, optional
+ Lower limit of score for valid bounding boxes.
+
+ id_index : int, optional
+ Index of the class categories. Set to ``-1`` to disable the class-id
check.
+
+ score_index : int, optional
+ Index of the scores/confidence of boxes.
+
+ Returns
+ -------
+ out : relax.Expr
+ A tuple ``(valid_count, out_tensor, out_indices)`` where
``valid_count``
+ has shape ``[batch_size]``, ``out_tensor`` has shape
+ ``[batch_size, num_anchors, elem_length]``, and ``out_indices`` has
shape
+ ``[batch_size, num_anchors]``.
+ """
+ return _ffi_api.get_valid_counts(data, score_threshold, id_index,
score_index)
+
+
+def non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+):
+ """Non-maximum suppression operator for object detection.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ 3-D tensor with shape [batch_size, num_anchors, elem_length].
+
+ valid_count : relax.Expr
+ 1-D tensor for valid number of boxes.
+
+ indices : relax.Expr
+ 2-D tensor with shape [batch_size, num_anchors].
+
+ max_output_size : int, optional
+ Max number of output valid boxes, -1 for no limit.
+
+ iou_threshold : float, optional
+ Non-maximum suppression IoU threshold.
+
+ force_suppress : bool, optional
+ Whether to suppress all detections regardless of class_id. When
+ ``id_index`` is ``-1``, all valid boxes are treated as belonging to the
+ same class, so this flag has the same effect as ``True``.
+
+ top_k : int, optional
+ Keep maximum top k detections before nms, -1 for no limit.
+
+ coord_start : int, optional
+ Start index of the consecutive 4 coordinates.
+
+ score_index : int, optional
+ Index of the scores/confidence of boxes.
+
+ id_index : int, optional
+ Index of the class categories. Set to ``-1`` to suppress boxes across
+ all classes.
+
+ return_indices : bool, optional
+ Whether to return box indices in input data.
+
+ invalid_to_bottom : bool, optional
+ Whether to move valid bounding boxes to the top of the returned tensor.
+ This option only affects the ``return_indices=False`` path.
+
+ Returns
+ -------
+ out : relax.Expr
+ If ``return_indices`` is ``True``, returns
+ ``(box_indices, valid_box_count)`` with shapes
+ ``[batch_size, num_anchors]`` and ``[batch_size, 1]``.
+ Otherwise returns the modified data tensor.
+ """
+ return _ffi_api.non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size,
+ iou_threshold,
+ force_suppress,
+ top_k,
+ coord_start,
+ score_index,
+ id_index,
+ return_indices,
+ invalid_to_bottom,
+ )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index 28367a67a3..ea0458bfce 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -120,6 +120,36 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.vision.get_valid_counts")
+def _get_valid_counts(block_builder: BlockBuilder, call: Call) -> Expr:
+ return block_builder.call_te(
+ topi.vision.get_valid_counts,
+ call.args[0],
+ score_threshold=call.attrs.score_threshold,
+ id_index=call.attrs.id_index,
+ score_index=call.attrs.score_index,
+ )
+
+
+@register_legalize("relax.vision.non_max_suppression")
+def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
+ return block_builder.call_te(
+ topi.vision.non_max_suppression,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ max_output_size=call.attrs.max_output_size,
+ iou_threshold=call.attrs.iou_threshold,
+ force_suppress=call.attrs.force_suppress,
+ top_k=call.attrs.top_k,
+ coord_start=call.attrs.coord_start,
+ score_index=call.attrs.score_index,
+ id_index=call.attrs.id_index,
+ return_indices=call.attrs.return_indices,
+ invalid_to_bottom=call.attrs.invalid_to_bottom,
+ )
+
+
@register_legalize("relax.vision.multibox_transform_loc")
def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
variances = tuple(float(x) for x in call.attrs.variances)
diff --git a/python/tvm/topi/testing/__init__.py
b/python/tvm/topi/testing/__init__.py
index d9fd005921..143ccb8459 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -54,9 +54,11 @@ from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
from .gather_nd_python import gather_nd_python
+from .get_valid_counts_python import get_valid_counts_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
from .batch_norm import batch_norm
+from .nms_python import non_max_suppression_python
from .slice_axis_python import slice_axis_python
from .sequence_mask_python import sequence_mask
from .poolnd_python import poolnd_python
diff --git a/python/tvm/topi/testing/get_valid_counts_python.py
b/python/tvm/topi/testing/get_valid_counts_python.py
new file mode 100644
index 0000000000..2caab6babc
--- /dev/null
+++ b/python/tvm/topi/testing/get_valid_counts_python.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Numpy reference implementation for get_valid_counts."""
+import numpy as np
+
+
+def get_valid_counts_python(data, score_threshold=0, id_index=0,
score_index=1):
+ """Numpy reference for get_valid_counts.
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ 3-D array with shape [batch_size, num_anchors, elem_length].
+
+ score_threshold : float
+ Lower limit of score for valid bounding boxes.
+
+ id_index : int
+ Index of the class categories, -1 to disable.
+
+ score_index : int
+ Index of the scores/confidence of boxes.
+
+ Returns
+ -------
+ valid_count : numpy.ndarray
+ 1-D array, shape [batch_size].
+
+ out_tensor : numpy.ndarray
+ Rearranged data, shape [batch_size, num_anchors, elem_length].
+
+ out_indices : numpy.ndarray
+ Indices mapping, shape [batch_size, num_anchors].
+ """
+ batch_size, num_anchors, box_data_length = data.shape
+ valid_count = np.zeros(batch_size, dtype="int32")
+ out_tensor = np.full_like(data, -1.0)
+ out_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
+
+ for i in range(batch_size):
+ cnt = 0
+ for j in range(num_anchors):
+ score = data[i, j, score_index]
+ if id_index < 0:
+ is_valid = score > score_threshold
+ else:
+ is_valid = score > score_threshold and data[i, j, id_index] >= 0
+ if is_valid:
+ out_tensor[i, cnt, :] = data[i, j, :]
+ out_indices[i, cnt] = j
+ cnt += 1
+ valid_count[i] = cnt
+
+ return valid_count, out_tensor, out_indices
diff --git a/python/tvm/topi/testing/nms_python.py
b/python/tvm/topi/testing/nms_python.py
new file mode 100644
index 0000000000..7c8c20f5b4
--- /dev/null
+++ b/python/tvm/topi/testing/nms_python.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Numpy reference implementation for classic non_max_suppression."""
+import numpy as np
+
+
+def _iou(box_a, box_b, coord_start):
+ """Compute IoU between two boxes."""
+ a = box_a[coord_start : coord_start + 4]
+ b = box_b[coord_start : coord_start + 4]
+
+ a_l, a_t, a_r, a_b = min(a[0], a[2]), min(a[1], a[3]), max(a[0], a[2]),
max(a[1], a[3])
+ b_l, b_t, b_r, b_b = min(b[0], b[2]), min(b[1], b[3]), max(b[0], b[2]),
max(b[1], b[3])
+
+ w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
+ h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
+ area = w * h
+ u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
+ return 0.0 if u <= 0 else area / u
+
+
+def non_max_suppression_python(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+):
+ """Numpy reference for classic non_max_suppression.
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ 3-D array, shape [batch_size, num_anchors, elem_length].
+
+ valid_count : numpy.ndarray
+ 1-D array, shape [batch_size].
+
+ indices : numpy.ndarray
+ 2-D array, shape [batch_size, num_anchors].
+
+ Returns
+ -------
+ If return_indices is True: (box_indices, valid_box_count)
+ Otherwise: modified data tensor
+ """
+ batch_size, num_anchors, _ = data.shape
+ out_data = np.full_like(data, -1.0)
+ out_box_indices = np.full((batch_size, num_anchors), -1, dtype="int32")
+ compacted = np.full((batch_size, num_anchors), -1, dtype="int32")
+ valid_box_count = np.zeros((batch_size, 1), dtype="int32")
+
+ for i in range(batch_size):
+ nkeep = int(valid_count[i])
+ if 0 < top_k < nkeep:
+ nkeep = top_k
+
+ # Sort by score descending
+ scores = data[i, :nkeep, score_index].copy()
+ sorted_idx = np.argsort(-scores)
+
+ # Copy sorted boxes
+ for j in range(nkeep):
+ src = sorted_idx[j]
+ out_data[i, j, :] = data[i, src, :]
+ out_box_indices[i, j] = src
+
+ # Greedy NMS
+ num_valid = 0
+ for j in range(nkeep):
+ if out_data[i, j, score_index] <= 0:
+ out_data[i, j, :] = -1.0
+ out_box_indices[i, j] = -1
+ continue
+ if 0 < max_output_size <= num_valid:
+ out_data[i, j, :] = -1.0
+ out_box_indices[i, j] = -1
+ continue
+
+ num_valid += 1
+
+ # Suppress overlapping boxes
+ for k in range(j + 1, nkeep):
+ if out_data[i, k, score_index] <= 0:
+ continue
+
+ do_suppress = False
+ if force_suppress:
+ do_suppress = True
+ elif id_index >= 0:
+ do_suppress = out_data[i, j, id_index] == out_data[i, k,
id_index]
+ else:
+ do_suppress = True
+
+ if do_suppress:
+ iou = _iou(out_data[i, j], out_data[i, k], coord_start)
+ if iou >= iou_threshold:
+ out_data[i, k, score_index] = -1.0
+ out_box_indices[i, k] = -1
+
+ if return_indices:
+ # Compact valid indices to top and remap to original
+ cnt = 0
+ for j in range(num_anchors):
+ if out_box_indices[i, j] >= 0:
+ orig_idx = out_box_indices[i, j]
+ compacted[i, cnt] = int(indices[i, orig_idx])
+ cnt += 1
+ valid_box_count[i, 0] = cnt
+
+ if return_indices:
+ return [compacted, valid_box_count]
+
+ if invalid_to_bottom:
+ # Rearrange valid boxes to top
+ result = np.full_like(data, -1.0)
+ for i in range(batch_size):
+ cnt = 0
+ for j in range(num_anchors):
+ if out_data[i, j, score_index] >= 0:
+ result[i, cnt, :] = out_data[i, j, :]
+ cnt += 1
+ return result
+
+ return out_data
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 9bdedc3535..b69e9c2aa1 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -36,37 +36,510 @@ from .nms_util import (
)
-def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): #
pylint: disable=unused-argument
+def _get_valid_counts_ir(
+ data, score_threshold, id_index, score_index, valid_count, out_tensor,
out_indices
+):
+ """IR for get_valid_counts. Filters boxes by score and compacts valid ones
to the top."""
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ box_data_length = data.shape[2]
+
+ with IRBuilder() as ib:
+ data = T.buffer_proxy(data)
+ valid_count = T.buffer_proxy(valid_count)
+ out_tensor = T.buffer_proxy(out_tensor)
+ out_indices = T.buffer_proxy(out_indices)
+
+ with T.parallel(0, batch_size) as i:
+ valid_count[i] = T.int32(0)
+
+ with T.serial(0, num_anchors) as j:
+ score = data[i, j, score_index]
+ if id_index < 0:
+ is_valid = score > score_threshold
+ else:
+ is_valid = tvm.tirx.all(score > score_threshold, data[i,
j, id_index] >= 0)
+
+ with T.If(is_valid):
+ with T.Then():
+ cur = valid_count[i]
+ with T.serial(0, box_data_length) as k:
+ out_tensor[i, cur, k] = data[i, j, k]
+ out_indices[i, cur] = j
+ valid_count[i] = cur + 1
+
+ # Fill remaining slots with -1
+ with T.serial(0, num_anchors) as j:
+ with T.If(j >= valid_count[i]):
+ with T.Then():
+ with T.serial(0, box_data_length) as k:
+ out_tensor[i, j, k] = tvm.tirx.Cast(data.dtype,
T.float32(-1.0))
+ out_indices[i, j] = T.int32(-1)
+
+ return ib.get()
+
+
+def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
"""Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data.
+
Parameters
----------
data : tvm.te.Tensor
- Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
- or [batch_size, num_anchors, 5].
+ Input data. 3-D tensor with shape [batch_size, num_anchors,
elem_length].
+
score_threshold : optional, float
Lower limit of score for valid bounding boxes.
+
id_index : optional, int
- index of the class categories, -1 to disable.
+ Index of the class categories, -1 to disable.
+
score_index: optional, int
Index of the scores/confidence of boxes.
+
Returns
-------
valid_count : tvm.te.Tensor
- 1-D tensor for valid number of boxes.
+ 1-D tensor for valid number of boxes, shape [batch_size].
+
out_tensor : tvm.te.Tensor
- Rearranged data tensor.
- out_indices: tvm.te.Tensor or numpy NDArray
- Related index in input data.
+ Rearranged data tensor, shape [batch_size, num_anchors, elem_length].
+
+ out_indices: tvm.te.Tensor
+ Related index in input data, shape [batch_size, num_anchors].
"""
- if isinstance(score_threshold, float | int):
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+ box_data_length = data.shape[2]
+
+ is_score_threshold_tensor = isinstance(score_threshold, te.Tensor)
+ if not is_score_threshold_tensor:
score_threshold = tvm.tirx.const(score_threshold, dtype=data.dtype)
- # id_index_const = tvm.tirx.const(id_index, "int32") # Unused
- # score_index_const = tvm.tirx.const(score_index, "int32") # Unused
- return (
- te.compute((data.shape[0],), lambda i: data.shape[1],
name="valid_count"),
- data,
- te.compute((data.shape[0], data.shape[1]), lambda i, j: j,
name="out_indices"),
+
+ id_index_const = tvm.tirx.const(id_index, "int32")
+ score_index_const = tvm.tirx.const(score_index, "int32")
+
+ valid_count_buf = tvm.tirx.decl_buffer((batch_size,), "int32",
"valid_count")
+ out_tensor_buf = tvm.tirx.decl_buffer(
+ (batch_size, num_anchors, box_data_length), data.dtype, "out_tensor"
+ )
+ out_indices_buf = tvm.tirx.decl_buffer(
+ (batch_size, num_anchors), "int32", "out_indices"
+ )
+
+ if is_score_threshold_tensor:
+ score_thresh_buf = tvm.tirx.decl_buffer(
+ score_threshold.shape, score_threshold.dtype, "score_threshold"
+ )
+ valid_count, out_tensor, out_indices = te.extern(
+ [(batch_size,), (batch_size, num_anchors, box_data_length),
(batch_size, num_anchors)],
+ [data, score_threshold],
+ lambda ins, outs: _get_valid_counts_ir(
+ ins[0], ins[1], id_index_const, score_index_const,
+ outs[0], outs[1], outs[2],
+ ),
+ dtype=["int32", data.dtype, "int32"],
+ out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf],
+ in_buffers=[
+ tvm.tirx.decl_buffer(data.shape, data.dtype, "data"),
+ score_thresh_buf,
+ ],
+ name="get_valid_counts",
+ tag="get_valid_counts",
+ )
+ else:
+ # score_threshold is a TIR constant, not a tensor
+ def _ir_with_const_threshold(ins, outs):
+ return _get_valid_counts_ir(
+ ins[0], score_threshold, id_index_const, score_index_const,
+ outs[0], outs[1], outs[2],
+ )
+
+ valid_count, out_tensor, out_indices = te.extern(
+ [(batch_size,), (batch_size, num_anchors, box_data_length),
(batch_size, num_anchors)],
+ [data],
+ _ir_with_const_threshold,
+ dtype=["int32", data.dtype, "int32"],
+ out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf],
+ in_buffers=[tvm.tirx.decl_buffer(data.shape, data.dtype, "data")],
+ name="get_valid_counts",
+ tag="get_valid_counts",
+ )
+
+ return valid_count, out_tensor, out_indices
+
+
+def _classic_nms_ir(
+ data,
+ sorted_index,
+ valid_count,
+ indices,
+ batch_size,
+ num_anchors,
+ box_data_length,
+ max_output_size,
+ iou_threshold,
+ force_suppress,
+ top_k,
+ coord_start,
+ score_index,
+ id_index,
+ return_indices,
+ out_data,
+ out_box_indices,
+ out_valid_box_count,
+):
+ """IR for classic single-class non-maximum suppression."""
+ with IRBuilder() as ib:
+ data = T.buffer_proxy(data)
+ sorted_index = T.buffer_proxy(sorted_index)
+ valid_count = T.buffer_proxy(valid_count)
+ indices = T.buffer_proxy(indices)
+ out_data = T.buffer_proxy(out_data)
+ out_box_indices = T.buffer_proxy(out_box_indices)
+ if out_valid_box_count is not None:
+ out_valid_box_count = T.buffer_proxy(out_valid_box_count)
+
+ with T.parallel(0, batch_size) as i:
+ # Step 1: Reorder data by sorted score
+ nkeep_buf = T.alloc_buffer((1,), "int32", scope="local")
+ nkeep_local = T.buffer_proxy(nkeep_buf)
+ nkeep_local[0] = valid_count[i]
+ with T.If(tvm.tirx.all(top_k > 0, top_k < nkeep_local[0])):
+ with T.Then():
+ nkeep_local[0] = top_k
+
+ # Copy sorted boxes to output
+ with T.serial(0, num_anchors) as j:
+ with T.If(j < nkeep_local[0]):
+ with T.Then():
+ src_idx = sorted_index[i, j]
+ with T.serial(0, box_data_length) as k:
+ out_data[i, j, k] = data[i, src_idx, k]
+ out_box_indices[i, j] = sorted_index[i, j]
+ with T.Else():
+ with T.serial(0, box_data_length) as k:
+ out_data[i, j, k] = tvm.tirx.Cast(data.dtype,
T.float32(-1.0))
+ out_box_indices[i, j] = T.int32(-1)
+
+ # Step 2: Apply NMS - greedy suppression
+ num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local")
+ num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf)
+ num_valid_boxes[0] = T.int32(0)
+
+ with T.serial(0, nkeep_local[0]) as j:
+ # Check if box j is still valid (score > 0) and within
max_output_size
+ with T.If(
+ tvm.tirx.all(
+ out_data[i, j, score_index] >
tvm.tirx.Cast(data.dtype, T.float32(0.0)),
+ tvm.tirx.Select(
+ max_output_size > 0,
+ num_valid_boxes[0] < max_output_size,
+ tvm.tirx.const(True),
+ ),
+ )
+ ):
+ with T.Then():
+ num_valid_boxes[0] = num_valid_boxes[0] + 1
+
+ # Suppress overlapping boxes
+ with T.serial(0, nkeep_local[0]) as k:
+ with T.If(
+ tvm.tirx.all(
+ k > j,
+ out_data[i, k, score_index]
+ > tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
+ )
+ ):
+ with T.Then():
+ # Check class ID match (or force_suppress)
+ do_suppress = tvm.tirx.const(False)
+ if force_suppress:
+ do_suppress = tvm.tirx.const(True)
+ elif id_index >= 0:
+ do_suppress = (
+ out_data[i, j, id_index] ==
out_data[i, k, id_index]
+ )
+ else:
+ do_suppress = tvm.tirx.const(True)
+
+ with T.If(do_suppress):
+ with T.Then():
+ # Calculate IoU
+ a_l = tvm.te.min(
+ out_data[i, j, coord_start],
+ out_data[i, j, coord_start +
2],
+ )
+ a_t = tvm.te.min(
+ out_data[i, j, coord_start +
1],
+ out_data[i, j, coord_start +
3],
+ )
+ a_r = tvm.te.max(
+ out_data[i, j, coord_start],
+ out_data[i, j, coord_start +
2],
+ )
+ a_b = tvm.te.max(
+ out_data[i, j, coord_start +
1],
+ out_data[i, j, coord_start +
3],
+ )
+
+ b_l = tvm.te.min(
+ out_data[i, k, coord_start],
+ out_data[i, k, coord_start +
2],
+ )
+ b_t = tvm.te.min(
+ out_data[i, k, coord_start +
1],
+ out_data[i, k, coord_start +
3],
+ )
+ b_r = tvm.te.max(
+ out_data[i, k, coord_start],
+ out_data[i, k, coord_start +
2],
+ )
+ b_b = tvm.te.max(
+ out_data[i, k, coord_start +
1],
+ out_data[i, k, coord_start +
3],
+ )
+
+ w = tvm.te.max(
+ tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
+ tvm.te.min(a_r, b_r) -
tvm.te.max(a_l, b_l),
+ )
+ h = tvm.te.max(
+ tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
+ tvm.te.min(a_b, b_b) -
tvm.te.max(a_t, b_t),
+ )
+ area = h * w
+ u = (
+ (a_r - a_l) * (a_b - a_t)
+ + (b_r - b_l) * (b_b - b_t)
+ - area
+ )
+ iou = tvm.tirx.Select(
+ u <= tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
+ tvm.tirx.Cast(data.dtype,
T.float32(0.0)),
+ area / u,
+ )
+
+ with T.If(iou >= iou_threshold):
+ with T.Then():
+ out_data[i, k,
score_index] = tvm.tirx.Cast(
+ data.dtype,
T.float32(-1.0)
+ )
+ out_box_indices[i, k] =
T.int32(-1)
+
+ with T.Else():
+ # Box suppressed or beyond max_output_size
+ with T.serial(0, box_data_length) as k:
+ out_data[i, j, k] = tvm.tirx.Cast(data.dtype,
T.float32(-1.0))
+ out_box_indices[i, j] = T.int32(-1)
+
+ # Step 3: If return_indices, remap to original indices
+ if return_indices:
+ if out_valid_box_count is not None:
+ # Count valid boxes and remap indices
+ valid_idx_buf = T.alloc_buffer((1,), "int32",
scope="local")
+ valid_idx = T.buffer_proxy(valid_idx_buf)
+ valid_idx[0] = T.int32(0)
+
+ with T.serial(0, num_anchors) as j:
+ with T.If(out_box_indices[i, j] >= 0):
+ with T.Then():
+ orig_idx = out_box_indices[i, j]
+ out_box_indices[i, valid_idx[0]] = indices[i,
orig_idx]
+ valid_idx[0] = valid_idx[0] + 1
+
+ out_valid_box_count[i, 0] = valid_idx[0]
+
+ # Fill remaining with -1
+ with T.serial(0, num_anchors) as j:
+ with T.If(j >= valid_idx[0]):
+ with T.Then():
+ out_box_indices[i, j] = T.int32(-1)
+
+ return ib.get()
+
+
def non_max_suppression(
    data,
    valid_count,
    indices,
    max_output_size=-1,
    iou_threshold=0.5,
    force_suppress=False,
    top_k=-1,
    coord_start=2,
    score_index=1,
    id_index=0,
    return_indices=True,
    invalid_to_bottom=False,
):
    """Classic non-maximum suppression operator for object detection.

    Suppresses overlapping boxes by IoU, operating on score-filtered box
    data such as the output of ``get_valid_counts``.

    Parameters
    ----------
    data : tvm.te.Tensor
        3-D tensor with shape [batch_size, num_anchors, elem_length].

    valid_count : tvm.te.Tensor
        1-D tensor for valid number of boxes, shape [batch_size].

    indices : tvm.te.Tensor
        2-D tensor with shape [batch_size, num_anchors].

    max_output_size : optional, int
        Max number of output valid boxes for each instance.
        Return all valid boxes if the value is less than 0.

    iou_threshold : optional, float
        Non-maximum suppression IoU threshold.

    force_suppress : optional, boolean
        Whether to suppress all detections regardless of class_id. When
        ``id_index`` is ``-1``, all valid boxes are treated as belonging to the
        same class, so this flag has the same effect as ``True``.

    top_k : optional, int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : required, int
        Start index of the consecutive 4 coordinates.

    score_index: optional, int
        Index of the scores/confidence of boxes.

    id_index : optional, int
        Index of the class categories, -1 to disable.

    return_indices : optional, boolean
        Whether to return box indices in input data.

    invalid_to_bottom : optional, boolean
        Whether to move all valid bounding boxes to the top (equivalently:
        push invalid entries to the bottom). Only applies when
        ``return_indices`` is False.

    Returns
    -------
    out : tvm.te.Tensor or tuple of tvm.te.Tensor
        If return_indices is True, returns a tuple of (box_indices, valid_box_count).
        Otherwise returns the modified data tensor.
    """
    batch_size, num_anchors, box_data_length = data.shape[0], data.shape[1], data.shape[2]

    # Normalize python scalars into TIR constants.
    if isinstance(max_output_size, int):
        max_output_size = tvm.tirx.const(max_output_size, dtype="int32")
    if isinstance(iou_threshold, (float, int)):
        iou_threshold = tvm.tirx.const(iou_threshold, dtype=data.dtype)

    # Rank boxes by descending score before suppression.
    scores = te.compute(
        (batch_size, num_anchors),
        lambda b, a: data[b, a, score_index],
        name="score_tensor",
    )
    sorted_index = argsort(scores, valid_count=valid_count, axis=1, is_ascend=False)

    data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data")
    sort_buf = tvm.tirx.decl_buffer(sorted_index.shape, sorted_index.dtype, "sorted_index")
    valid_count_buf = tvm.tirx.decl_buffer(valid_count.shape, valid_count.dtype, "valid_count")
    indices_buf = tvm.tirx.decl_buffer(indices.shape, indices.dtype, "indices")

    out_data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "out_data")
    out_box_indices_buf = tvm.tirx.decl_buffer(
        (batch_size, num_anchors), "int32", "out_box_indices"
    )

    def _make_ir(ins, outs):
        # The valid-count output buffer only exists when indices are returned.
        count_out = outs[2] if len(outs) == 3 else None
        return _classic_nms_ir(
            ins[0], ins[1], ins[2], ins[3],
            batch_size, num_anchors, box_data_length,
            max_output_size, iou_threshold,
            force_suppress, top_k,
            coord_start, score_index, id_index,
            return_indices,
            outs[0], outs[1], count_out,
        )

    extern_inputs = [data, sorted_index, valid_count, indices]
    extern_in_bufs = [data_buf, sort_buf, valid_count_buf, indices_buf]

    if return_indices:
        out_valid_box_count_buf = tvm.tirx.decl_buffer(
            (batch_size, 1), "int32", "out_valid_box_count"
        )
        _, out_box_indices, out_valid_box_count = te.extern(
            [data.shape, (batch_size, num_anchors), (batch_size, 1)],
            extern_inputs,
            _make_ir,
            dtype=[data.dtype, "int32", "int32"],
            out_buffers=[out_data_buf, out_box_indices_buf, out_valid_box_count_buf],
            in_buffers=extern_in_bufs,
            name="non_max_suppression",
            tag="non_max_suppression",
        )
        return [out_box_indices, out_valid_box_count]

    out_data, _ = te.extern(
        [data.shape, (batch_size, num_anchors)],
        extern_inputs,
        _make_ir,
        dtype=[data.dtype, "int32"],
        out_buffers=[out_data_buf, out_box_indices_buf],
        in_buffers=extern_in_bufs,
        name="non_max_suppression",
        tag="non_max_suppression",
    )

    if invalid_to_bottom:
        # Compact surviving boxes to the front of each batch.
        return _rearrange_out(out_data, batch_size, num_anchors, box_data_length, score_index)

    return out_data
+
+
def _rearrange_out(data, batch_size, num_anchors, box_data_length, score_index):
    """Compact boxes with non-negative scores to the top; pad the tail with -1."""
    rearranged_buf = tvm.tirx.decl_buffer(
        (batch_size, num_anchors, box_data_length), data.dtype, "rearranged"
    )

    def _compact_ir(ins, outs):
        with IRBuilder() as ib:
            src = T.buffer_proxy(ins[0])
            dst = T.buffer_proxy(outs[0])

            with T.parallel(0, batch_size) as b:
                # Per-batch write cursor for the next valid slot.
                write_pos_buf = T.alloc_buffer((1,), "int32", scope="local")
                write_pos = T.buffer_proxy(write_pos_buf)
                write_pos[0] = T.int32(0)

                # Copy every box whose score is non-negative to the front.
                with T.serial(0, num_anchors) as a:
                    with T.If(
                        src[b, a, score_index] >= tvm.tirx.Cast(data.dtype, T.float32(0.0))
                    ):
                        with T.Then():
                            with T.serial(0, box_data_length) as e:
                                dst[b, write_pos[0], e] = src[b, a, e]
                            write_pos[0] = write_pos[0] + 1

                # Mark all remaining slots invalid.
                with T.serial(0, num_anchors) as a:
                    with T.If(a >= write_pos[0]):
                        with T.Then():
                            with T.serial(0, box_data_length) as e:
                                dst[b, a, e] = tvm.tirx.Cast(data.dtype, T.float32(-1.0))

        return ib.get()

    return te.extern(
        [(batch_size, num_anchors, box_data_length)],
        [data],
        _compact_ir,
        dtype=[data.dtype],
        out_buffers=[rearranged_buf],
        name="rearrange_out",
        tag="rearrange_out",
    )
diff --git a/python/tvm/topi/vision/nms_util.py
b/python/tvm/topi/vision/nms_util.py
index ae17168970..a4b4c78363 100644
--- a/python/tvm/topi/vision/nms_util.py
+++ b/python/tvm/topi/vision/nms_util.py
@@ -303,8 +303,8 @@ def _all_class_nms_ir(
if selected_scores is not None:
selected_scores = T.buffer_proxy(selected_scores)
- if isinstance(iou_threshold, float):
- iou_threshold = tvm.tirx.FloatImm("float32", iou_threshold)
+ if isinstance(iou_threshold, (float, int)):
+ iou_threshold = tvm.tirx.FloatImm("float32", float(iou_threshold))
elif isinstance(iou_threshold, te.Tensor):
if len(iou_threshold.shape) == 0:
iou_threshold = iou_threshold()
diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc
index 294cd40c45..97508d7211 100644
--- a/src/relax/op/vision/nms.cc
+++ b/src/relax/op/vision/nms.cc
@@ -18,6 +18,7 @@
*/
#include "nms.h"
+#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/attrs.h>
@@ -33,7 +34,11 @@
namespace tvm {
namespace relax {
-TVM_FFI_STATIC_INIT_BLOCK() {
AllClassNonMaximumSuppressionAttrs::RegisterReflection(); }
+TVM_FFI_STATIC_INIT_BLOCK() {
+ AllClassNonMaximumSuppressionAttrs::RegisterReflection();
+ GetValidCountsAttrs::RegisterReflection();
+ NonMaximumSuppressionAttrs::RegisterReflection();
+}
/* relax.vision.all_class_non_max_suppression */
@@ -110,5 +115,242 @@
TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllClassNMS)
.set_attr<Bool>("FPurity", Bool(true));
/* relax.vision.get_valid_counts */

/*!
 * \brief Build a relax.vision.get_valid_counts call node.
 * \param data 3-D box tensor [batch_size, num_anchors, elem_length].
 * \param score_threshold Boxes with score at or below this are filtered out
 *        (exact comparison semantics live in the TOPI kernel — see legalization).
 * \param id_index Index of the class id within elem_length; -1 disables it.
 * \param score_index Index of the score within elem_length.
 * \return The constructed Call expression.
 */
Expr get_valid_counts(Expr data, double score_threshold, int id_index, int score_index) {
  auto attrs = tvm::ffi::make_object<GetValidCountsAttrs>();
  attrs->score_threshold = score_threshold;
  attrs->id_index = id_index;
  attrs->score_index = score_index;

  static const Op& op = Op::Get("relax.vision.get_valid_counts");
  return Call(op, {std::move(data)}, Attrs(attrs), {});
}

// Expose the op constructor to the FFI (used by the Python wrapper).
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("relax.op.vision.get_valid_counts", get_valid_counts);
}
+
/*!
 * \brief Infer the struct info of relax.vision.get_valid_counts.
 *
 * The op produces a 3-tuple:
 *   [0] valid_count : int32 tensor of shape [batch_size]
 *   [1] out_data    : tensor of the input dtype, shape [batch_size, num_anchors, elem_length]
 *   [2] out_indices : int32 tensor of shape [batch_size, num_anchors]
 */
StructInfo InferStructInfoGetValidCounts(const Call& call, const BlockBuilder& ctx) {
  // Exactly one tensor argument is accepted.
  if (call->args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "get_valid_counts expects 1 argument, got " << call->args.size());
  }

  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  if (data_sinfo == nullptr) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "get_valid_counts expects input data to be a Tensor.");
  }
  // ndim == -1 means unknown rank, which is allowed; any other non-3 rank is rejected.
  if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "get_valid_counts expects 3-D input, got ndim " << data_sinfo->ndim);
  }

  const auto* attrs = call->attrs.as<GetValidCountsAttrs>();
  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid get_valid_counts attrs";
  auto vdev = data_sinfo->vdevice;
  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
  if (data_shape == nullptr) {
    // Without a symbolic shape we can only report ranks and dtypes.
    tvm::ffi::Array<StructInfo> fields = {
        TensorStructInfo(DataType::Int(32), /*ndim=*/1, vdev),
        TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev),
        TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)};
    return TupleStructInfo(fields);
  }

  auto batch = data_shape->values[0];
  auto num_anchors = data_shape->values[1];
  auto elem_length = data_shape->values[2];
  // score_index / id_index can only be range-checked when elem_length is a
  // static constant (see the PR's stated limitation for the dynamic case).
  const auto* elem_length_imm = elem_length.as<IntImmNode>();
  if (elem_length_imm != nullptr) {
    if (attrs->score_index < 0 || attrs->score_index >= elem_length_imm->value) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "get_valid_counts expects score_index to be in range [0, "
                       << elem_length_imm->value << "), but got " << attrs->score_index);
    }
    // -1 is a valid sentinel for id_index (class ids disabled).
    if (attrs->id_index < -1 || attrs->id_index >= elem_length_imm->value) {
      ctx->ReportFatal(Diagnostic::Error(call)
                       << "get_valid_counts expects id_index to be in range [-1, "
                       << elem_length_imm->value << "), but got " << attrs->id_index);
    }
  }

  tvm::ffi::Array<StructInfo> fields = {
      TensorStructInfo(ShapeExpr({batch}), DataType::Int(32), vdev),
      TensorStructInfo(ShapeExpr({batch, num_anchors, elem_length}), data_sinfo->dtype, vdev),
      TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev)};
  return TupleStructInfo(fields);
}

TVM_REGISTER_OP("relax.vision.get_valid_counts")
    .set_attrs_type<GetValidCountsAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor",
                  "Input data, 3-D tensor [batch_size, num_anchors, elem_length].")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGetValidCounts)
    .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.vision.non_max_suppression */

/*!
 * \brief Build a classic relax.vision.non_max_suppression call node.
 * \param data 3-D box tensor [batch_size, num_anchors, elem_length].
 * \param valid_count 1-D int32 tensor of valid box counts per batch.
 * \param indices 2-D int32 tensor [batch_size, num_anchors] mapping boxes
 *        back to their original positions.
 * \param max_output_size Max boxes kept per instance; negative keeps all.
 * \param iou_threshold Boxes overlapping a kept box above this IoU are suppressed.
 * \param force_suppress Suppress across classes regardless of class id.
 * \param top_k Keep only the top-k scored boxes before suppression; -1 for all.
 * \param coord_start Index of the first of the 4 consecutive coordinates.
 * \param score_index Index of the score within elem_length.
 * \param id_index Index of the class id within elem_length; -1 disables it.
 * \param return_indices Return (box_indices, valid_box_count) instead of box data.
 * \param invalid_to_bottom Compact valid boxes to the top of the output data.
 * \return The constructed Call expression.
 */
Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_output_size,
                         double iou_threshold, bool force_suppress, int top_k, int coord_start,
                         int score_index, int id_index, bool return_indices,
                         bool invalid_to_bottom) {
  auto attrs = tvm::ffi::make_object<NonMaximumSuppressionAttrs>();
  attrs->max_output_size = max_output_size;
  attrs->iou_threshold = iou_threshold;
  attrs->force_suppress = force_suppress;
  attrs->top_k = top_k;
  attrs->coord_start = coord_start;
  attrs->score_index = score_index;
  attrs->id_index = id_index;
  attrs->return_indices = return_indices;
  attrs->invalid_to_bottom = invalid_to_bottom;

  static const Op& op = Op::Get("relax.vision.non_max_suppression");
  return Call(op, {std::move(data), std::move(valid_count), std::move(indices)}, Attrs(attrs), {});
}

// Expose the op constructor to the FFI (used by the Python wrapper).
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("relax.op.vision.non_max_suppression", non_max_suppression);
}
+
+StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects 3 arguments, got " <<
call->args.size());
+ }
+
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* valid_count_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+ const auto* indices_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects input data to be a
Tensor.");
+ }
+ if (valid_count_sinfo == nullptr) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects valid_count to be a
Tensor.");
+ }
+ if (indices_sinfo == nullptr) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects indices to be a Tensor.");
+ }
+ if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects 3-D input, got ndim " <<
data_sinfo->ndim);
+ }
+ if (valid_count_sinfo->ndim != -1 && valid_count_sinfo->ndim != 1) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects valid_count to be 1-D,
got ndim "
+ << valid_count_sinfo->ndim);
+ }
+ if (indices_sinfo->ndim != -1 && indices_sinfo->ndim != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects indices to be 2-D, got
ndim "
+ << indices_sinfo->ndim);
+ }
+ if (!valid_count_sinfo->IsUnknownDtype() && valid_count_sinfo->dtype !=
DataType::Int(32)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects valid_count to have dtype
int32, got "
+ << valid_count_sinfo->dtype);
+ }
+ if (!indices_sinfo->IsUnknownDtype() && indices_sinfo->dtype !=
DataType::Int(32)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects indices to have dtype
int32, got "
+ << indices_sinfo->dtype);
+ }
+
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ const auto* valid_count_shape = valid_count_sinfo->shape.as<ShapeExprNode>();
+ const auto* indices_shape = indices_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape != nullptr) {
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ PrimExpr batch = data_shape->values[0];
+ PrimExpr num_anchors = data_shape->values[1];
+ if (valid_count_shape != nullptr &&
+ !analyzer->CanProveEqual(valid_count_shape->values[0], batch)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects valid_count to have
shape [batch_size]. "
+ "However, the given data tensor has batch size `"
+ << batch << "` and the given valid_count tensor has
shape "
+ << valid_count_sinfo->shape);
+ }
+ if (indices_shape != nullptr) {
+ if (!analyzer->CanProveEqual(indices_shape->values[0], batch) ||
+ !analyzer->CanProveEqual(indices_shape->values[1], num_anchors)) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "non_max_suppression expects indices to have shape [batch_size,
num_anchors]. "
+ "However, the given data tensor has shape "
+ << data_sinfo->shape << " and the given indices tensor has shape "
+ << indices_sinfo->shape);
+ }
+ }
+ }
+
+ const auto* attrs = call->attrs.as<NonMaximumSuppressionAttrs>();
+ TVM_FFI_ICHECK(attrs != nullptr) << "Invalid non_max_suppression attrs";
+ auto vdev = data_sinfo->vdevice;
+ if (data_shape != nullptr) {
+ const auto* elem_length_imm = data_shape->values[2].as<IntImmNode>();
+ if (elem_length_imm != nullptr) {
+ int64_t elem_length = elem_length_imm->value;
+ if (attrs->score_index < 0 || attrs->score_index >= elem_length) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects score_index to be in
range [0, "
+ << elem_length << "), but got " <<
attrs->score_index);
+ }
+ if (attrs->coord_start < 0 || attrs->coord_start + 3 >= elem_length) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects coord_start to
reference four "
+ "consecutive box coordinates within elem_length "
+ << elem_length << ", but got " << attrs->coord_start);
+ }
+ if (attrs->id_index < -1 || attrs->id_index >= elem_length) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "non_max_suppression expects id_index to be in
range [-1, "
+ << elem_length << "), but got " << attrs->id_index);
+ }
+ }
+ }
+
+ if (attrs->return_indices) {
+ // Returns (box_indices[batch, num_anchors], valid_box_count[batch, 1])
+ if (data_shape == nullptr) {
+ tvm::ffi::Array<StructInfo> fields = {
+ TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev),
+ TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)};
+ return TupleStructInfo(fields);
+ }
+ auto batch = data_shape->values[0];
+ auto num_anchors = data_shape->values[1];
+ tvm::ffi::Array<StructInfo> fields = {
+ TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32),
vdev),
+ TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}),
DataType::Int(32),
+ vdev)};
+ return TupleStructInfo(fields);
+ }
+
+ // Returns modified data tensor with the same shape as input.
+ if (const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>()) {
+ return TensorStructInfo(ffi::GetRef<ShapeExpr>(data_shape),
data_sinfo->dtype, vdev);
+ }
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev);
+}
+
+TVM_REGISTER_OP("relax.vision.non_max_suppression")
+ .set_attrs_type<NonMaximumSuppressionAttrs>()
+ .set_num_inputs(3)
+ .add_argument("data", "Tensor",
+ "Input data, 3-D tensor [batch_size, num_anchors,
elem_length].")
+ .add_argument("valid_count", "Tensor", "1-D tensor for valid number of
boxes.")
+ .add_argument("indices", "Tensor", "2-D tensor with shape [batch_size,
num_anchors].")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNMS)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h
index c86bf98c94..3fbd2609e2 100644
--- a/src/relax/op/vision/nms.h
+++ b/src/relax/op/vision/nms.h
@@ -38,6 +38,15 @@ Expr all_class_non_max_suppression(Expr boxes, Expr scores,
Expr max_output_boxe
Expr iou_threshold, Expr score_threshold,
ffi::String output_format);
+/*! \brief Get valid count of bounding boxes given a score threshold. */
+Expr get_valid_counts(Expr data, double score_threshold, int id_index, int
score_index);
+
+/*! \brief Non-maximum suppression for object detection. */
+Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int
max_output_size,
+ double iou_threshold, bool force_suppress, int top_k,
int coord_start,
+ int score_index, int id_index, bool return_indices,
+ bool invalid_to_bottom);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_op_vision.py
b/tests/python/relax/test_op_vision.py
index cded9f5f29..6d04a796ca 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -20,6 +20,7 @@ import pytest
import tvm
import tvm.testing
+import tvm.topi.testing
from tvm import TVMError, relax, tirx
from tvm.ir import Op
from tvm.relax.transform import LegalizeOps
@@ -31,6 +32,23 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
def _assert_relax_op_legalized(mod: tvm.IRModule, op_name: str) -> None:
    """Assert legalization introduced relax.call_tir and removed ``op_name``."""
    observed_ops = set()

    def _collect(expr):
        if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
            observed_ops.add(expr.op.name)

    relax.analysis.post_order_visit(mod["main"].body, _collect)
    assert "relax.call_tir" in observed_ops
    assert op_name not in observed_ops
+
+
def test_roi_align_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
@@ -198,6 +216,840 @@ def test_roi_align_legalize_sample_ratio_zero():
)
+def test_get_valid_counts_op_correctness():
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ assert relax.op.vision.get_valid_counts(data, 0.5).op ==
Op.get("relax.vision.get_valid_counts")
+
+
+def test_get_valid_counts_infer_struct_info():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ _check_inference(
+ bb,
+ relax.op.vision.get_valid_counts(data, score_threshold=0.5,
id_index=0, score_index=1),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2,), "int32"),
+ relax.TensorStructInfo((2, 10, 6), "float32"),
+ relax.TensorStructInfo((2, 10), "int32"),
+ ]
+ ),
+ )
+
+
+def test_get_valid_counts_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ n = tirx.Var("n", "int64")
+ m = tirx.Var("m", "int64")
+ k = tirx.Var("k", "int64")
+ data = relax.Var("data", R.Tensor((n, m, k), "float32"))
+ _check_inference(
+ bb,
+ relax.op.vision.get_valid_counts(data, score_threshold=0.0),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((n,), "int32"),
+ relax.TensorStructInfo((n, m, k), "float32"),
+ relax.TensorStructInfo((n, m), "int32"),
+ ]
+ ),
+ )
+
+
+def test_get_valid_counts_wrong_ndim():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((10, 6), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.get_valid_counts(data))
+
+
+def test_get_valid_counts_invalid_indices():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.get_valid_counts(data, score_index=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.get_valid_counts(data, id_index=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.get_valid_counts(data, id_index=-2))
+
+
+def test_nms_op_correctness():
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ assert relax.op.vision.non_max_suppression(
+ data, valid_count, indices
+ ).op == Op.get("relax.vision.non_max_suppression")
+
+
+def test_nms_infer_struct_info_return_indices():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ _check_inference(
+ bb,
+ relax.op.vision.non_max_suppression(
+ data, valid_count, indices, return_indices=True
+ ),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((2, 10), "int32"),
+ relax.TensorStructInfo((2, 1), "int32"),
+ ]
+ ),
+ )
+
+
+def test_nms_infer_struct_info_return_data():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ _check_inference(
+ bb,
+ relax.op.vision.non_max_suppression(
+ data, valid_count, indices, return_indices=False
+ ),
+ relax.TensorStructInfo((2, 10, 6), "float32"),
+ )
+
+
+def test_nms_infer_struct_info_return_data_shape_var():
+ bb = relax.BlockBuilder()
+ batch_size = tirx.Var("batch_size", "int64")
+ num_anchors = tirx.Var("num_anchors", "int64")
+ elem_length = tirx.Var("elem_length", "int64")
+ data = relax.Var("data", R.Tensor((batch_size, num_anchors, elem_length),
"float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((batch_size,), "int32"))
+ indices = relax.Var("indices", R.Tensor((batch_size, num_anchors),
"int32"))
+ _check_inference(
+ bb,
+ relax.op.vision.non_max_suppression(
+ data, valid_count, indices, return_indices=False
+ ),
+ relax.TensorStructInfo((batch_size, num_anchors, elem_length),
"float32"),
+ )
+
+
+def test_nms_wrong_ndim():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices))
+
+
+def test_nms_wrong_valid_count_ndim():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2, 1), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices))
+
+
+def test_nms_wrong_indices_ndim():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((20,), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices))
+
+
+def test_nms_wrong_aux_input_dtype():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count_i64 = relax.Var("valid_count_i64", R.Tensor((2,), "int64"))
+ valid_count_i32 = relax.Var("valid_count_i32", R.Tensor((2,), "int32"))
+ indices_i64 = relax.Var("indices_i64", R.Tensor((2, 10), "int64"))
+ indices_i32 = relax.Var("indices_i32", R.Tensor((2, 10), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data,
valid_count_i64, indices_i32))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data,
valid_count_i32, indices_i64))
+
+
+def test_nms_wrong_aux_input_shape():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count_bad_batch = relax.Var("valid_count_bad_batch", R.Tensor((3,),
"int32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices_bad_batch = relax.Var("indices_bad_batch", R.Tensor((3, 10),
"int32"))
+ indices_bad_anchors = relax.Var("indices_bad_anchors", R.Tensor((2, 9),
"int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(
+ relax.op.vision.non_max_suppression(
+ data, valid_count_bad_batch, indices_bad_anchors
+ )
+ )
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices_bad_batch))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices_bad_anchors))
+
+
+def test_nms_invalid_indices():
+ bb = relax.BlockBuilder()
+ data = relax.Var("data", R.Tensor((2, 10, 6), "float32"))
+ valid_count = relax.Var("valid_count", R.Tensor((2,), "int32"))
+ indices = relax.Var("indices", R.Tensor((2, 10), "int32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices, score_index=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices, id_index=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices, id_index=-2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.non_max_suppression(data, valid_count,
indices, coord_start=3))
+
+
+def test_get_valid_counts_legalize():
+ @tvm.script.ir_module
+ class GVC:
+ @R.function
+ def main(
+ data: R.Tensor((1, 5, 6), "float32"),
+ ) -> R.Tuple(
+ R.Tensor((1,), "int32"),
+ R.Tensor((1, 5, 6), "float32"),
+ R.Tensor((1, 5), "int32"),
+ ):
+ gv = R.vision.get_valid_counts(data, score_threshold=0.5,
id_index=0, score_index=1)
+ return gv
+
+ mod = LegalizeOps()(GVC)
+ _assert_relax_op_legalized(mod, "relax.vision.get_valid_counts")
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((1,), "int32"),
+ relax.TensorStructInfo((1, 5, 6), "float32"),
+ relax.TensorStructInfo((1, 5), "int32"),
+ ]
+ ),
+ )
+
+
+def test_nms_legalize():
+ @tvm.script.ir_module
+ class NMS:
+ @R.function
+ def main(
+ data: R.Tensor((1, 5, 6), "float32"),
+ valid_count: R.Tensor((1,), "int32"),
+ indices: R.Tensor((1, 5), "int32"),
+ ) -> R.Tuple(R.Tensor((1, 5), "int32"), R.Tensor((1, 1), "int32")):
+ gv = R.vision.non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+ )
+ return gv
+
+ mod = LegalizeOps()(NMS)
+ _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression")
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((1, 5), "int32"),
+ relax.TensorStructInfo((1, 1), "int32"),
+ ]
+ ),
+ )
+
+
+def test_nms_legalize_return_data():
+ @tvm.script.ir_module
+ class NMS:
+ @R.function
+ def main(
+ data: R.Tensor((1, 5, 6), "float32"),
+ valid_count: R.Tensor((1,), "int32"),
+ indices: R.Tensor((1, 5), "int32"),
+ ) -> R.Tensor((1, 5, 6), "float32"):
+ gv = R.vision.non_max_suppression(
+ data,
+ valid_count,
+ indices,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=False,
+ invalid_to_bottom=True,
+ )
+ return gv
+
+ mod = LegalizeOps()(NMS)
+ _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression")
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TensorStructInfo((1, 5, 6), "float32"),
+ )
+
+
[email protected]_llvm
+def test_get_valid_counts_e2e():
+ """Run get_valid_counts through legalization and compare with the numpy
reference."""
+
+ @tvm.script.ir_module
+ class GVCModule:
+ @R.function
+ def main(
+ data: R.Tensor((2, 5, 6), "float32"),
+ ) -> R.Tuple(
+ R.Tensor((2,), "int32"),
+ R.Tensor((2, 5, 6), "float32"),
+ R.Tensor((2, 5), "int32"),
+ ):
+ return R.vision.get_valid_counts(data, score_threshold=0.5,
id_index=0, score_index=1)
+
+ data_np = np.array(
+ [
+ [
+ [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+ [1.0, 0.30, 0.0, 0.0, 1.0, 1.0],
+ [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0],
+ [2.0, 0.75, 2.0, 2.0, 3.0, 3.0],
+ [1.0, 0.10, 4.0, 4.0, 5.0, 5.0],
+ ],
+ [
+ [0.0, 0.55, 0.0, 0.0, 1.0, 1.0],
+ [1.0, 0.80, 1.0, 1.0, 2.0, 2.0],
+ [2.0, 0.40, 2.0, 2.0, 3.0, 3.0],
+ [3.0, 0.60, 3.0, 3.0, 4.0, 4.0],
+ [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0],
+ ],
+ ],
+ dtype="float32",
+ )
+ ref_valid_count, ref_out_data, ref_out_indices =
tvm.topi.testing.get_valid_counts_python(
+ data_np, score_threshold=0.5, id_index=0, score_index=1
+ )
+
+ mod = LegalizeOps()(GVCModule)
+ exe = tvm.compile(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, tvm.cpu())
+ result = vm["main"](tvm.runtime.tensor(data_np, tvm.cpu()))
+
+ tvm.testing.assert_allclose(result[0].numpy(), ref_valid_count)
+ tvm.testing.assert_allclose(result[1].numpy(), ref_out_data)
+ tvm.testing.assert_allclose(result[2].numpy(), ref_out_indices)
+
+
def _prepare_nms_inputs(raw_data: np.ndarray):
    """Filter raw boxes via the numpy get_valid_counts reference (threshold 0.5)."""
    return tvm.topi.testing.get_valid_counts_python(
        raw_data,
        score_threshold=0.5,
        id_index=0,
        score_index=1,
    )
+
+
def _run_nms_e2e(
    data_np: np.ndarray,
    valid_count_np: np.ndarray,
    indices_np: np.ndarray,
    *,
    max_output_size: int = -1,
    iou_threshold: float = 0.5,
    force_suppress: bool = False,
    top_k: int = -1,
    coord_start: int = 2,
    score_index: int = 1,
    id_index: int = 0,
    return_indices: bool = True,
    invalid_to_bottom: bool = False,
):
    """Build, legalize, compile, and run a classic-NMS module on the Relax VM."""

    def _tensor_var(name, arr, dtype):
        shape = tuple(int(dim) for dim in arr.shape)
        return relax.Var(name, relax.TensorStructInfo(shape, dtype))

    data = _tensor_var("data", data_np, "float32")
    valid_count = _tensor_var("valid_count", valid_count_np, "int32")
    indices = _tensor_var("indices", indices_np, "int32")

    bb = relax.BlockBuilder()
    with bb.function("main", (data, valid_count, indices)):
        out = bb.emit(
            relax.op.vision.non_max_suppression(
                data,
                valid_count,
                indices,
                max_output_size=max_output_size,
                iou_threshold=iou_threshold,
                force_suppress=force_suppress,
                top_k=top_k,
                coord_start=coord_start,
                score_index=score_index,
                id_index=id_index,
                return_indices=return_indices,
                invalid_to_bottom=invalid_to_bottom,
            )
        )
        bb.emit_func_output(out)

    exe = tvm.compile(LegalizeOps()(bb.get()), target="llvm")
    dev = tvm.cpu()
    vm = relax.VirtualMachine(exe, dev)
    return vm["main"](
        tvm.runtime.tensor(data_np, dev),
        tvm.runtime.tensor(valid_count_np, dev),
        tvm.runtime.tensor(indices_np, dev),
    )
+
+
[email protected]_llvm
+def test_nms_e2e_return_indices():
+ """Run classic NMS through legalization and compare with the numpy
reference."""
+
+ raw_data = np.array(
+ [
+ [
+ [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
+ [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
+ [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0],
+ ]
+ ],
+ dtype="float32",
+ )
+ valid_count_np, filtered_data_np, filtered_indices_np =
_prepare_nms_inputs(raw_data)
+ ref_indices, ref_valid_box_count =
tvm.topi.testing.non_max_suppression_python(
+ filtered_data_np,
+ valid_count_np,
+ filtered_indices_np,
+ max_output_size=-1,
+ iou_threshold=0.5,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+ )
+ result = _run_nms_e2e(
+ filtered_data_np,
+ valid_count_np,
+ filtered_indices_np,
+ return_indices=True,
+ invalid_to_bottom=False,
+ )
+
+ tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
+ tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_return_indices_with_invalid_to_bottom():
    """Validate that invalid_to_bottom is a no-op when returning indices.

    The numpy reference is computed with invalid_to_bottom=False while the
    compiled run enables it; both must yield identical indices and counts.
    """

    boxes = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
                [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0],
            ]
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want_indices, want_count = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    # invalid_to_bottom deliberately differs from the reference call above.
    got = _run_nms_e2e(data, counts, order, return_indices=True, invalid_to_bottom=True)

    tvm.testing.assert_allclose(got[0].numpy(), want_indices)
    tvm.testing.assert_allclose(got[1].numpy(), want_count)
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_top_k():
    """Validate that classic NMS honors top_k before suppression."""

    boxes = np.array(
        [
            [
                [-1.0, 0.99, 9.0, 9.0, 10.0, 10.0],
                [0.0, 0.97, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.96, 2.0, 2.0, 3.0, 3.0],
                [0.0, 0.95, 4.0, 4.0, 5.0, 5.0],
                [1.0, 0.94, 6.0, 6.0, 7.0, 7.0],
                [0.0, 0.20, 8.0, 8.0, 9.0, 9.0],
            ]
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want_indices, want_count = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=2,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    got = _run_nms_e2e(data, counts, order, top_k=2, return_indices=True, invalid_to_bottom=False)

    tvm.testing.assert_allclose(got[0].numpy(), want_indices)
    tvm.testing.assert_allclose(got[1].numpy(), want_count)
    # With top_k=2 only the two highest-scoring valid boxes can survive.
    np.testing.assert_array_equal(want_indices, np.array([[1, 2, -1, -1, -1, -1]], dtype="int32"))
    np.testing.assert_array_equal(want_count, np.array([[2]], dtype="int32"))
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_force_suppress():
    """Validate that force_suppress ignores class ids when suppressing overlaps."""

    boxes = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [1.0, 0.90, 0.05, 0.05, 1.05, 1.05],
                [1.0, 0.80, 2.0, 2.0, 3.0, 3.0],
                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
            ]
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want_indices, want_count = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=True,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    got = _run_nms_e2e(data, counts, order, force_suppress=True, return_indices=True, invalid_to_bottom=False)

    tvm.testing.assert_allclose(got[0].numpy(), want_indices)
    tvm.testing.assert_allclose(got[1].numpy(), want_count)
    # Box 1 overlaps box 0 and is suppressed despite its different class id.
    np.testing.assert_array_equal(want_indices, np.array([[0, 2, -1, -1]], dtype="int32"))
    np.testing.assert_array_equal(want_count, np.array([[2]], dtype="int32"))
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_max_output_size():
    """Validate that max_output_size truncates the kept boxes after score sorting.

    With iou_threshold=1.0 no box suppresses another, so the only remaining
    effect is the max_output_size=2 truncation of the score-sorted survivors.
    """

    raw_data = np.array(
        [
            [
                [0.0, 0.97, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.95, 2.0, 2.0, 3.0, 3.0],
                [0.0, 0.93, 4.0, 4.0, 5.0, 5.0],
                [0.0, 0.91, 6.0, 6.0, 7.0, 7.0],
            ]
        ],
        dtype="float32",
    )
    valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data)
    # iou_threshold is a float attribute: pass 1.0 (not the int literal 1) for
    # consistency with the other tests in this file.
    ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python(
        filtered_data_np,
        valid_count_np,
        filtered_indices_np,
        max_output_size=2,
        iou_threshold=1.0,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    result = _run_nms_e2e(
        filtered_data_np,
        valid_count_np,
        filtered_indices_np,
        max_output_size=2,
        iou_threshold=1.0,
        return_indices=True,
        invalid_to_bottom=False,
    )

    tvm.testing.assert_allclose(result[0].numpy(), ref_indices)
    tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count)
    np.testing.assert_array_equal(ref_indices, np.array([[0, 1, -1, -1]], dtype="int32"))
    np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32"))
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_multi_batch():
    """Validate that classic NMS processes each batch independently."""

    boxes = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
                [1.0, 0.80, 2.0, 2.0, 3.0, 3.0],
                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
            ],
            [
                [1.0, 0.96, 0.0, 0.0, 1.0, 1.0],
                [2.0, 0.94, 0.04, 0.04, 1.04, 1.04],
                [2.0, 0.88, 3.0, 3.0, 4.0, 4.0],
                [2.0, 0.30, 6.0, 6.0, 7.0, 7.0],
            ],
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want_indices, want_count = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    got = _run_nms_e2e(data, counts, order, return_indices=True, invalid_to_bottom=False)

    tvm.testing.assert_allclose(got[0].numpy(), want_indices)
    tvm.testing.assert_allclose(got[1].numpy(), want_count)
    # Each batch is suppressed on its own: batch 0 keeps 2 boxes, batch 1 keeps 3.
    np.testing.assert_array_equal(
        want_indices,
        np.array([[0, 2, -1, -1], [0, 1, 2, -1]], dtype="int32"),
    )
    np.testing.assert_array_equal(want_count, np.array([[2], [3]], dtype="int32"))
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_invalid_to_bottom():
    """Validate that invalid_to_bottom compacts only boxes that remain valid after NMS."""

    boxes = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
            ]
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=False,
        invalid_to_bottom=True,
    )
    got = _run_nms_e2e(data, counts, order, return_indices=False, invalid_to_bottom=True)
    # Surviving boxes are moved to the front; suppressed slots become -1 rows.
    expected = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
            ]
        ],
        dtype="float32",
    )

    tvm.testing.assert_allclose(got.numpy(), want)
    tvm.testing.assert_allclose(got.numpy(), expected)
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_return_data_without_compaction():
    """Validate the return_indices=False path when invalid boxes stay in-place."""

    boxes = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
                [-1.0, 0.99, 8.0, 8.0, 9.0, 9.0],
            ]
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=False,
        invalid_to_bottom=False,
    )
    got = _run_nms_e2e(data, counts, order, return_indices=False, invalid_to_bottom=False)
    # Without compaction suppressed boxes are overwritten with -1 in place.
    expected = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
                [1.0, 0.85, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.60, 2.0, 2.0, 3.0, 3.0],
                [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
            ]
        ],
        dtype="float32",
    )

    tvm.testing.assert_allclose(got.numpy(), want)
    tvm.testing.assert_allclose(got.numpy(), expected)
+
+
@tvm.testing.requires_llvm
def test_nms_e2e_index_remap():
    """Validate that returned indices remap from filtered order back to original order."""

    boxes = np.array(
        [
            [
                [-1.0, 0.99, 9.0, 9.0, 10.0, 10.0],
                [0.0, 0.60, 4.0, 4.0, 5.0, 5.0],
                [0.0, 0.10, 8.0, 8.0, 9.0, 9.0],
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [0.0, 0.90, 0.05, 0.05, 1.05, 1.05],
                [1.0, 0.80, 2.0, 2.0, 3.0, 3.0],
            ]
        ],
        dtype="float32",
    )
    counts, data, order = _prepare_nms_inputs(boxes)
    want_indices, want_count = tvm.topi.testing.non_max_suppression_python(
        data,
        counts,
        order,
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    got = _run_nms_e2e(data, counts, order, return_indices=True, invalid_to_bottom=False)

    tvm.testing.assert_allclose(got[0].numpy(), want_indices)
    tvm.testing.assert_allclose(got[1].numpy(), want_count)
    # Indices refer to positions in the ORIGINAL (pre-filter) box order.
    np.testing.assert_array_equal(want_indices, np.array([[3, 5, 1, -1, -1, -1]], dtype="int32"))
    np.testing.assert_array_equal(want_count, np.array([[3]], dtype="int32"))
+
+
def test_all_class_non_max_suppression_infer_struct_info():
bb = relax.BlockBuilder()
batch_size, num_classes, num_boxes = 10, 8, 5
@@ -450,11 +1302,11 @@ def _multibox_ref_numpy(
boxes = np.zeros((B, N, 4), dtype=np.float32)
for b in range(B):
for a in range(N):
- l, t, r, br = anchor[0, a, :]
- ay = (t + br) * 0.5
- ax = (l + r) * 0.5
- ah = br - t
- aw = r - l
+ left, top, right, bottom = anchor[0, a, :]
+ ay = (top + bottom) * 0.5
+ ax = (left + right) * 0.5
+ ah = bottom - top
+ aw = right - left
ex, ey, ew, eh = loc[b, a, :]
ycenter = ey * vy * ah + ay
xcenter = ex * vx * aw + ax
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py
b/tests/python/relax/test_tvmscript_parser_op_vision.py
index f053e36744..370b68769e 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,6 +75,138 @@ def test_all_class_non_max_suppression():
_check(foo, bb.get()["foo"])
def test_get_valid_counts():
    # Expected TVMScript form of the function built programmatically below.
    @R.function
    def foo(
        data: R.Tensor((10, 5, 6), "float32"),
    ) -> R.Tuple(
        R.Tensor((10,), "int32"),
        R.Tensor((10, 5, 6), "float32"),
        R.Tensor((10, 5), "int32"),
    ):
        gv: R.Tuple(
            R.Tensor((10,), "int32"),
            R.Tensor((10, 5, 6), "float32"),
            R.Tensor((10, 5), "int32"),
        ) = R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1)
        return gv

    # Build the same function with the BlockBuilder and compare structurally.
    inp = relax.Var("data", R.Tensor((10, 5, 6), "float32"))
    builder = relax.BlockBuilder()
    with builder.function("foo", [inp]):
        out = builder.emit(
            relax.op.vision.get_valid_counts(inp, score_threshold=0.5, id_index=0, score_index=1)
        )
        builder.emit_func_output(out)

    _check(foo, builder.get()["foo"])
+
+
def test_non_max_suppression_return_indices():
    # Expected TVMScript form: return_indices=True yields an (indices, count) tuple.
    @R.function
    def foo(
        data: R.Tensor((2, 5, 6), "float32"),
        valid_count: R.Tensor((2,), "int32"),
        indices: R.Tensor((2, 5), "int32"),
    ) -> R.Tuple(R.Tensor((2, 5), "int32"), R.Tensor((2, 1), "int32")):
        gv: R.Tuple(R.Tensor((2, 5), "int32"), R.Tensor((2, 1), "int32")) = (
            R.vision.non_max_suppression(
                data,
                valid_count,
                indices,
                max_output_size=-1,
                iou_threshold=0.5,
                force_suppress=False,
                top_k=3,
                coord_start=2,
                score_index=1,
                id_index=0,
                return_indices=True,
                invalid_to_bottom=False,
            )
        )
        return gv

    # Same attribute set, built programmatically.
    nms_attrs = dict(
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=3,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=True,
        invalid_to_bottom=False,
    )
    boxes = relax.Var("data", R.Tensor((2, 5, 6), "float32"))
    counts = relax.Var("valid_count", R.Tensor((2,), "int32"))
    order = relax.Var("indices", R.Tensor((2, 5), "int32"))

    builder = relax.BlockBuilder()
    with builder.function("foo", [boxes, counts, order]):
        out = builder.emit(relax.op.vision.non_max_suppression(boxes, counts, order, **nms_attrs))
        builder.emit_func_output(out)

    _check(foo, builder.get()["foo"])
+
+
def test_non_max_suppression_return_data():
    # Expected TVMScript form: return_indices=False yields the filtered data tensor.
    @R.function
    def foo(
        data: R.Tensor((2, 5, 6), "float32"),
        valid_count: R.Tensor((2,), "int32"),
        indices: R.Tensor((2, 5), "int32"),
    ) -> R.Tensor((2, 5, 6), "float32"):
        gv: R.Tensor((2, 5, 6), "float32") = R.vision.non_max_suppression(
            data,
            valid_count,
            indices,
            max_output_size=-1,
            iou_threshold=0.5,
            force_suppress=False,
            top_k=-1,
            coord_start=2,
            score_index=1,
            id_index=0,
            return_indices=False,
            invalid_to_bottom=True,
        )
        return gv

    # Same attribute set, built programmatically.
    nms_attrs = dict(
        max_output_size=-1,
        iou_threshold=0.5,
        force_suppress=False,
        top_k=-1,
        coord_start=2,
        score_index=1,
        id_index=0,
        return_indices=False,
        invalid_to_bottom=True,
    )
    boxes = relax.Var("data", R.Tensor((2, 5, 6), "float32"))
    counts = relax.Var("valid_count", R.Tensor((2,), "int32"))
    order = relax.Var("indices", R.Tensor((2, 5), "int32"))

    builder = relax.BlockBuilder()
    with builder.function("foo", [boxes, counts, order]):
        out = builder.emit(relax.op.vision.non_max_suppression(boxes, counts, order, **nms_attrs))
        builder.emit_func_output(out)

    _check(foo, builder.get()["foo"])
+
+
def test_multibox_transform_loc():
@R.function
def foo(