This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b14b023080 [Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS 
tflite operator (#19345)
b14b023080 is described below

commit b14b02308049f9569e2843bd937194dc90fe3171
Author: HoYi <[email protected]>
AuthorDate: Sat Apr 11 13:42:09 2026 +0800

    [Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS tflite operator 
(#19345)
    
    ## Summary
    
    - Implemented the TFLite `DETECTION_POSTPROCESS` operator conversion to
    Relax IR.
    - Wired up the previously unimplemented operator to support object
    detection post-processing workflows in Relax.
    - Relates to #18928
    
    ## Changes
    
    - **Operator Registration**: Implemented `convert_detection_postprocess`
    in `python/tvm/relax/frontend/tflite/tflite_frontend.py`.
    - **Core Logic**:
    - Integrated `multibox_transform_loc` for coordinate decoding and
    variance scaling.
    - Supported `use_regular_nms` attribute to switch between all-class NMS
    and class-agnostic NMS paths.
    - Leveraged `all_class_non_max_suppression` for efficient box filtering.
    - **Output Alignment**: Used `topk`, `gather_nd`, and `where` operators
    to ensure the output tensors (boxes, classes, scores, num_detections)
    match the TFLite specification in terms of shape and layout.
    - **Attribute Validation**: Added strict validation for required custom
    options such as `num_classes`, `max_detections`, and scaling factors.
    
    ## Validation
    
    Verified with linting and pre-commit hooks:
    
    ```bash
    # Lint check
    python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py
    
    # Pre-commit checks
    python -m pre_commit run --files 
python/tvm/relax/frontend/tflite/tflite_frontend.py
    ```
    
    Result:
    - **Passed**: All static checks and style guidelines are met.
---
 .../tvm/relax/frontend/tflite/tflite_flexbuffer.py |  27 ++-
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 223 ++++++++++++-----
 python/tvm/relax/transform/legalize_ops/vision.py  |  19 +-
 tests/python/relax/test_frontend_tflite.py         | 268 +++++++++++++++++++++
 4 files changed, 460 insertions(+), 77 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py 
b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
index dc8ce1df21..5152b6996e 100644
--- a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
+++ b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
@@ -78,12 +78,7 @@ class FlexBufferDecoder:
 
     def indirect_jump(self, offset, byte_width):
         """Helper function to read the offset value and jump"""
-        unpack_str = ""
-        if byte_width == 1:
-            unpack_str = "<B"
-        elif byte_width == 4:
-            unpack_str = "<i"
-        assert unpack_str != ""
+        unpack_str = {1: "<B", 2: "<H", 4: "<I", 8: "<Q"}[byte_width]
         back_jump = struct.unpack(unpack_str, self.buffer[offset : offset + 
byte_width])[0]
         return offset - back_jump
 
@@ -107,19 +102,26 @@ class FlexBufferDecoder:
         # Each entry in the vector can have different datatype. Each entry is 
of fixed length. The
         # format is a sequence of all values followed by a sequence of 
datatype of all values. For
         # example - (4)(3.56)(int)(float) The end here points to the start of 
the values.
+        # Each type byte contains: (type << 2) | bit_width, where bit_width 
determines actual size.
         values = list()
         for i in range(0, size):
             value_type_pos = end + size * byte_width + i
-            value_type = FlexBufferType(self.buffer[value_type_pos] >> 2)
-            value_bytes = self.buffer[end + i * byte_width : end + (i + 1) * 
byte_width]
+            value_type_packed = self.buffer[value_type_pos]
+            value_type = FlexBufferType(value_type_packed >> 2)
+            value_bit_width = BitWidth(value_type_packed & 3)
+            value_byte_width = 1 << value_bit_width
+            value_bytes = self.buffer[end + i * byte_width : end + i * 
byte_width + value_byte_width]
             if value_type == FlexBufferType.FBT_BOOL:
                 value = bool(value_bytes[0])
             elif value_type == FlexBufferType.FBT_INT:
-                value = struct.unpack("<i", value_bytes)[0]
+                fmt = {1: "<b", 2: "<h", 4: "<i", 8: "<q"}[value_byte_width]
+                value = struct.unpack(fmt, value_bytes)[0]
             elif value_type == FlexBufferType.FBT_UINT:
-                value = struct.unpack("<I", value_bytes)[0]
+                fmt = {1: "<B", 2: "<H", 4: "<I", 8: "<Q"}[value_byte_width]
+                value = struct.unpack(fmt, value_bytes)[0]
             elif value_type == FlexBufferType.FBT_FLOAT:
-                value = struct.unpack("<f", value_bytes)[0]
+                fmt = {4: "<f", 8: "<d"}[value_byte_width]
+                value = struct.unpack(fmt, value_bytes)[0]
             else:
                 raise Exception
             values.append(value)
@@ -128,7 +130,8 @@ class FlexBufferDecoder:
     def decode_map(self, end, byte_width, parent_byte_width):
         """Decodes the flexbuffer map and returns a dict"""
         mid_loc = self.indirect_jump(end, parent_byte_width)
-        map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width : 
mid_loc])[0]
+        size_fmt = {1: "<b", 2: "<h", 4: "<i", 8: "<q"}[byte_width]
+        map_size = struct.unpack(size_fmt, self.buffer[mid_loc - byte_width : 
mid_loc])[0]
 
         # Find keys
         keys_offset = mid_loc - byte_width * 3
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index b344d9361a..16d5cb636b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -2832,7 +2832,9 @@ class OperatorConverter:
             new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in 
shape_b]
             max_rank = max(rank_a, rank_b)
 
-            batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in 
range(max_rank - 2)]
+            batch_shape = [
+                max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 
2)
+            ]
 
             a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
             b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
@@ -3225,16 +3227,49 @@ class OperatorConverter:
 
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
-        raise NotImplementedError(
-            "DETECTION_POSTPROCESS is not wired in this frontend yet: it still 
needs "
-            "Relax NMS / get_valid_counts / related vision helpers (see dead 
code below). "
-            "relax.vision.multibox_transform_loc exists; tracking: "
-            "https://github.com/apache/tvm/issues/18928"
-        )
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
         custom_options = FlexBufferDecoder(flexbuffer).decode()
 
-        use_regular_nms = "use_regular_nms" in custom_options and 
custom_options["use_regular_nms"]
+        use_regular_nms = bool(custom_options.get("use_regular_nms", False))
+
+        required_attrs = [
+            "num_classes",
+            "max_detections",
+            "detections_per_class",
+            "nms_iou_threshold",
+            "nms_score_threshold",
+            "x_scale",
+            "y_scale",
+            "w_scale",
+            "h_scale",
+        ]
+        missing_attrs = [key for key in required_attrs if key not in 
custom_options]
+        if missing_attrs:
+            raise ValueError(
+                "DETECTION_POSTPROCESS custom options miss required 
attributes: "
+                + ", ".join(missing_attrs)
+            )
+
+        num_classes = int(custom_options["num_classes"])
+        max_detections = int(custom_options["max_detections"])
+        detections_per_class = int(custom_options["detections_per_class"])
+        iou_threshold = float(custom_options["nms_iou_threshold"])
+        score_threshold = float(custom_options["nms_score_threshold"])
+        x_scale = float(custom_options["x_scale"])
+        y_scale = float(custom_options["y_scale"])
+        w_scale = float(custom_options["w_scale"])
+        h_scale = float(custom_options["h_scale"])
+
+        if num_classes <= 0:
+            raise ValueError("DETECTION_POSTPROCESS requires num_classes > 0.")
+        if max_detections <= 0:
+            raise ValueError("DETECTION_POSTPROCESS requires max_detections > 
0.")
+        if detections_per_class <= 0:
+            raise ValueError("DETECTION_POSTPROCESS requires 
detections_per_class > 0.")
+        if not 0.0 <= iou_threshold <= 1.0:
+            raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold 
in [0, 1].")
+        if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <= 
0.0:
+            raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to 
be > 0.")
 
         inputs = self.get_input_tensors(op)
         assert len(inputs) == 3, "inputs length should be 3"
@@ -3296,67 +3331,139 @@ class OperatorConverter:
         # attributes for multibox_transform_loc
         multibox_transform_loc_attrs = {}
         multibox_transform_loc_attrs["clip"] = False
-        multibox_transform_loc_attrs["threshold"] = (
-            0.0 if use_regular_nms else custom_options["nms_score_threshold"]
-        )
+        multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms 
else score_threshold
         multibox_transform_loc_attrs["variances"] = (
-            1 / custom_options["x_scale"],
-            1 / custom_options["y_scale"],
-            1 / custom_options["w_scale"],
-            1 / custom_options["h_scale"],
+            1 / x_scale,
+            1 / y_scale,
+            1 / w_scale,
+            1 / h_scale,
         )
         multibox_transform_loc_attrs["keep_background"] = use_regular_nms
 
-        ret = relax.op.vision.multibox_transform_loc(
-            # reshape cls_pred so it can be consumed by
-            # multibox_transform_loc
-            relax.op.permute_dims(cls_pred, [0, 2, 1]),
-            loc_prob,
-            anchor_expr,
-            **multibox_transform_loc_attrs,
+        multibox_res = self.bb.emit(
+            relax.op.vision.multibox_transform_loc(
+                # reshape cls_pred so it can be consumed by
+                # multibox_transform_loc
+                relax.op.permute_dims(cls_pred, [0, 2, 1]),
+                loc_prob,
+                anchor_expr,
+                **multibox_transform_loc_attrs,
+            )
+        )
+        transformed_boxes = self.bb.emit(relax.TupleGetItem(multibox_res, 0))
+        transformed_scores = self.bb.emit(relax.TupleGetItem(multibox_res, 1))
+
+        if use_regular_nms:
+            nms_out = self.bb.emit(
+                relax.op.vision.all_class_non_max_suppression(
+                    transformed_boxes,
+                    transformed_scores,
+                    relax.const(detections_per_class, "int64"),
+                    relax.const(iou_threshold, "float32"),
+                    relax.const(score_threshold, "float32"),
+                    output_format="tensorflow",
+                )
+            )
+            selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0))
+            selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1))
+            num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2))
+            class_id_from_score = None
+        else:
+            topk_res = self.bb.emit(
+                relax.op.topk(transformed_scores, k=1, axis=1, 
ret_type="both", largest=True)
+            )
+            max_scores = self.bb.emit(relax.TupleGetItem(topk_res, 0))
+            class_id_from_score = self.bb.emit(relax.TupleGetItem(topk_res, 1))
+            nms_out = self.bb.emit(
+                relax.op.vision.all_class_non_max_suppression(
+                    transformed_boxes,
+                    max_scores,
+                    relax.const(max_detections, "int64"),
+                    relax.const(iou_threshold, "float32"),
+                    relax.const(score_threshold, "float32"),
+                    output_format="tensorflow",
+                )
+            )
+            selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0))
+            selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1))
+            num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2))
+            class_id_from_score = relax.op.squeeze(class_id_from_score, 
axis=[1])
+
+        selected_score_slots = selected_scores.struct_info.shape.values[1]
+        selected_detection_positions = relax.op.expand_dims(
+            relax.op.arange(selected_score_slots, dtype="int64"), axis=0
+        )
+        selected_valid_detection_mask = relax.op.less(
+            selected_detection_positions, relax.op.expand_dims(num_detections, 
axis=1)
+        )
+        masked_selected_scores = relax.op.where(
+            selected_valid_detection_mask,
+            selected_scores,
+            relax.const(-1.0, "float32"),
+        )
+        topk_scores_res = self.bb.emit(
+            relax.op.topk(
+                masked_selected_scores, k=max_detections, axis=1, 
ret_type="both", largest=True
+            )
+        )
+        detection_scores = self.bb.emit(relax.TupleGetItem(topk_scores_res, 0))
+        top_positions = self.bb.emit(relax.TupleGetItem(topk_scores_res, 1))
+        num_detections = relax.op.minimum(
+            num_detections, relax.const([max_detections], dtype="int64")
+        )
+        detection_positions = relax.op.expand_dims(
+            relax.op.arange(max_detections, dtype="int64"), axis=0
+        )
+        valid_detection_mask = relax.op.less(
+            detection_positions, relax.op.expand_dims(num_detections, axis=1)
+        )
+        top_positions_expanded = relax.op.expand_dims(top_positions, axis=2)
+        top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, 
axis=2)
+        top_index_pairs = relax.op.gather_elements(
+            selected_indices, top_positions_for_pairs, axis=1
+        )
+        top_box_ids = relax.op.squeeze(
+            relax.op.strided_slice(top_index_pairs, axes=[2], begin=[1], 
end=[2]),
+            axis=[2],
+        )
+        top_box_ids_for_gather = 
relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2)
+        detection_boxes = relax.op.gather_nd(
+            transformed_boxes, top_box_ids_for_gather, batch_dims=1
         )
 
         if use_regular_nms:
-            # box coordinates need to be converted from ltrb to (ymin, xmin, 
ymax, xmax)
-            _, transformed_boxes = relax.op.split(ret[0], (2,), axis=2)
-            box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4, 
axis=2)
-            transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r], 
axis=2)
-
-            return relax.op.vision.regular_non_max_suppression(
-                boxes=transformed_boxes,
-                scores=cls_pred,
-                
max_detections_per_class=custom_options["detections_per_class"],
-                max_detections=custom_options["max_detections"],
-                num_classes=custom_options["num_classes"],
-                iou_threshold=custom_options["nms_iou_threshold"],
-                score_threshold=custom_options["nms_score_threshold"],
+            detection_classes = relax.op.squeeze(
+                relax.op.strided_slice(top_index_pairs, axes=[2], begin=[0], 
end=[1]),
+                axis=[2],
+            )
+            detection_classes = relax.op.astype(detection_classes, "int32")
+        else:
+            top_box_ids_for_class = relax.op.expand_dims(
+                relax.op.astype(top_box_ids, "int64"), axis=2
+            )
+            detection_classes = relax.op.gather_nd(
+                class_id_from_score, top_box_ids_for_class, batch_dims=1
             )
 
-        # attributes for non_max_suppression
-        non_max_suppression_attrs = {}
-        non_max_suppression_attrs["return_indices"] = False
-        non_max_suppression_attrs["iou_threshold"] = 
custom_options["nms_iou_threshold"]
-        non_max_suppression_attrs["force_suppress"] = True
-        non_max_suppression_attrs["top_k"] = anchor_boxes
-        non_max_suppression_attrs["max_output_size"] = 
custom_options["max_detections"]
-        non_max_suppression_attrs["invalid_to_bottom"] = False
-
-        ret = relax.op.vision.non_max_suppression(
-            ret[0], ret[1], ret[1], **non_max_suppression_attrs
+        detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2)
+        detection_boxes = relax.op.where(
+            detection_mask,
+            detection_boxes,
+            relax.op.zeros((batch_size, max_detections, 4), dtype="float32"),
+        )
+        detection_classes = relax.op.where(
+            valid_detection_mask,
+            detection_classes,
+            relax.op.zeros((batch_size, max_detections), dtype="int32"),
         )
-        ret = relax.op.vision.get_valid_counts(ret, 0)
-        valid_count = ret[0]
-        # keep only the top 'max_detections' rows
-        ret = relax.op.strided_slice(
-            ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], 
6]
+        detection_scores = relax.op.where(
+            valid_detection_mask,
+            detection_scores,
+            relax.op.zeros((batch_size, max_detections), dtype="float32"),
         )
-        # the output needs some reshaping to match tflite
-        ret = relax.op.split(ret, 6, axis=2)
-        cls_ids = relax.op.reshape(ret[0], [batch_size, -1])
-        scores = relax.op.reshape(ret[1], [batch_size, -1])
-        boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2)
-        ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]), 
size=4)
-        return ret
+        detection_classes = relax.op.astype(detection_classes, "float32")
+        num_detections = relax.op.astype(num_detections, "float32")
+        return relax.Tuple([detection_boxes, detection_classes, 
detection_scores, num_detections])
 
     def convert_nms_v5(self, op):
         """Convert TFLite NonMaxSuppressionV5"""
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py 
b/python/tvm/relax/transform/legalize_ops/vision.py
index 7d8586ab52..c515fc8fe8 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -32,11 +32,15 @@ def _all_class_non_max_suppression(block_builder: 
BlockBuilder, call: Call) -> E
 
     Returns
     -------
-    result : Tuple[Tensor, Tensor]
-        A tuple of (trimmed_indices, num_total_detections) where:
-        - trimmed_indices: Tensor of shape (num_total_detections, 3) 
containing only
-          valid detection indices (batch_id, class_id, box_id)
-        - num_total_detections: Tensor of shape (1,) with the count of valid 
detections
+    result : Expr
+        The legalized NMS result.
+
+        - For ONNX output format, returns a tuple of
+          `(trimmed_indices, num_total_detections)`, where `trimmed_indices`
+          contains only valid detection indices.
+        - For TensorFlow output format, returns the TOPI result directly to
+          preserve the `(selected_indices, selected_scores, num_detections)`
+          layout expected by the Relax op.
     """
     boxes = call.args[0]
     scores = call.args[1]
@@ -69,8 +73,9 @@ def _all_class_non_max_suppression(block_builder: 
BlockBuilder, call: Call) -> E
         output_format,
     )
 
-    # Dynamic output trimming using dynamic_strided_slice
-    # Extract selected_indices and num_total_detections from the NMS result
+    if output_format == "tensorflow":
+        return nms_result
+
     selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
     num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))
 
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 02282f3d41..c237d4db8f 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -27,6 +27,7 @@ import tflite.Model
 from tensorflow.keras import applications as keras_app
 
 import tvm
+import tvm.relax.frontend.tflite.tflite_frontend as tflite_frontend
 from tvm import relax
 from tvm.relax.frontend.tflite import from_tflite
 from tvm.script.parser import ir as I
@@ -1082,6 +1083,142 @@ def _build_nms_v5_mod(num_boxes, max_output_size, 
iou_threshold, score_threshold
     return mod, instance.func
 
 
+class _StubDetectionPostprocessTensor:
+    def __init__(self, shape, name):
+        self._shape = list(shape)
+        self._name = name
+
+    def Shape(self, index):
+        return self._shape[index]
+
+    def Name(self):
+        return self._name
+
+    def Type(self):
+        return 0
+
+
+class _StubDetectionPostprocessOp:
+    def __init__(self, custom_options):
+        self._custom_options = 
_encode_detection_postprocess_custom_options(custom_options)
+
+    def CustomOptionsAsNumpy(self):
+        return np.frombuffer(self._custom_options, dtype="uint8")
+
+
+_DETECTION_POSTPROCESS_ANCHORS = np.array(
+    [
+        [0.5, 0.5, 1.0, 1.0],
+        [0.5, 0.2, 1.0, 1.0],
+        [0.1, 0.1, 0.5, 0.5],
+        [0.8, 0.8, 0.2, 0.2],
+    ],
+    dtype="float32",
+)
+
+
+def _encode_detection_postprocess_custom_options(custom_options):
+    from flatbuffers import flexbuffers
+
+    builder = flexbuffers.Builder()
+    with builder.Map():
+        for key, value in custom_options.items():
+            if isinstance(value, bool):
+                builder.Bool(key, value)
+            elif isinstance(value, int):
+                builder.Int(key, value)
+            else:
+                builder.Float(key, float(value))
+    return bytes(builder.Finish())
+
+
+def _make_detection_postprocess_tensor_wrapper(tensor_idx, shape, name):
+    return tflite_frontend.TensorWrapper(
+        tensor_idx,
+        _StubDetectionPostprocessTensor(shape, name),
+        None,
+    )
+
+
+def _build_detection_postprocess_mod(
+    *,
+    num_classes=1,
+    max_detections=4,
+    detections_per_class=4,
+    use_regular_nms=False,
+    nms_iou_threshold=0.5,
+    nms_score_threshold=0.3,
+    x_scale=10.0,
+    y_scale=10.0,
+    w_scale=5.0,
+    h_scale=5.0,
+    batch_size=2,
+    num_anchors=4,
+    input_num_classes=None,
+):
+    custom_options = {
+        "num_classes": num_classes,
+        "max_detections": max_detections,
+        "detections_per_class": detections_per_class,
+        "nms_iou_threshold": nms_iou_threshold,
+        "nms_score_threshold": nms_score_threshold,
+        "x_scale": x_scale,
+        "y_scale": y_scale,
+        "w_scale": w_scale,
+        "h_scale": h_scale,
+        "use_regular_nms": use_regular_nms,
+    }
+    return _convert_detection_postprocess_with_options(
+        custom_options,
+        batch_size=batch_size,
+        num_anchors=num_anchors,
+        num_classes=num_classes,
+        input_num_classes=input_num_classes,
+    )
+
+
+def _convert_detection_postprocess_with_options(
+    custom_options,
+    *,
+    batch_size=2,
+    num_anchors=4,
+    num_classes=1,
+    input_num_classes=None,
+    build_module=True,
+):
+    input_num_classes = num_classes if input_num_classes is None else 
input_num_classes
+    loc = relax.Var("loc", relax.TensorStructInfo((batch_size, num_anchors, 
4), "float32"))
+    cls = relax.Var(
+        "cls", relax.TensorStructInfo((batch_size, num_anchors, 
input_num_classes), "float32")
+    )
+    inputs = [
+        _make_detection_postprocess_tensor_wrapper(0, (batch_size, 
num_anchors, 4), "loc"),
+        _make_detection_postprocess_tensor_wrapper(
+            1, (batch_size, num_anchors, input_num_classes), "cls"
+        ),
+        _make_detection_postprocess_tensor_wrapper(2, (num_anchors, 4), 
"anchors"),
+    ]
+    converter = 
tflite_frontend.OperatorConverter.__new__(tflite_frontend.OperatorConverter)
+    converter.bb = relax.BlockBuilder()
+    converter.exp_tab = tflite_frontend.ExprTable()
+    converter.get_input_tensors = lambda op: inputs
+    converter.get_expr = lambda tensor_idx: {0: loc, 1: cls}[tensor_idx]
+    converter.get_tensor_value = (
+        lambda tensor: _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx == 
2 else None
+    )
+    converter.get_tensor_type_str = lambda tensor_type: "float32"
+    op = _StubDetectionPostprocessOp(custom_options)
+    if not build_module:
+        return converter.convert_detection_postprocess(op)
+    bb = converter.bb
+    with bb.function("main", [loc, cls]):
+        with bb.dataflow():
+            output = converter.convert_detection_postprocess(op)
+            gv = bb.emit_output(output)
+        bb.emit_func_output(gv)
+    return bb.get()
+
+
 def _make_valid_boxes(rng, n):
     """Generate n random boxes with y1<=y2, x1<=x2 using the given RNG."""
     raw = rng.random((n, 4), dtype=np.float32)
@@ -1207,6 +1344,137 @@ def test_nms_v5_ir():
     assert f"R.Tensor(({max_output_size},)" in ir
 
 
+_DETECTION_POSTPROCESS_SMOKE_CASES = [
+    pytest.param(
+        {
+            "num_classes": 2,
+            "input_num_classes": 3,
+            "max_detections": 2,
+            "detections_per_class": 2,
+            "use_regular_nms": False,
+            "nms_iou_threshold": 0.5,
+            "nms_score_threshold": 0.5,
+            "batch_size": 1,
+            "num_anchors": 4,
+        },
+        2,
+        False,
+        id="basic_fast_nms",
+    ),
+    pytest.param(
+        {
+            "num_classes": 2,
+            "input_num_classes": 3,
+            "max_detections": 3,
+            "detections_per_class": 2,
+            "use_regular_nms": True,
+            "nms_iou_threshold": 0.45,
+            "nms_score_threshold": 0.25,
+            "batch_size": 2,
+            "num_anchors": 4,
+        },
+        1,
+        True,
+        id="regular_nms_multi_batch",
+    ),
+]
+
+
+_DETECTION_POSTPROCESS_SHAPE_CASES = [
+    pytest.param(
+        {
+            "num_classes": 2,
+            "input_num_classes": 5,
+            "max_detections": 2,
+            "detections_per_class": 2,
+            "use_regular_nms": False,
+            "nms_iou_threshold": 0.5,
+            "nms_score_threshold": 0.5,
+            "batch_size": 1,
+            "num_anchors": 4,
+        },
+        id="wider_input_classes",
+    ),
+    pytest.param(
+        {
+            "num_classes": 2,
+            "input_num_classes": 3,
+            "max_detections": 4,
+            "detections_per_class": 4,
+            "use_regular_nms": False,
+            "nms_iou_threshold": 0.5,
+            "nms_score_threshold": 0.5,
+            "batch_size": 1,
+            "num_anchors": 4,
+        },
+        id="larger_max_detections",
+    ),
+]
+
+
[email protected](
+    "build_kwargs,expected_topk_count,expected_keep_background",
+    _DETECTION_POSTPROCESS_SMOKE_CASES,
+)
+def test_detection_postprocess_smoke(
+    build_kwargs, expected_topk_count, expected_keep_background
+):
+    mod = _build_detection_postprocess_mod(**build_kwargs)
+    ir = mod.script()
+
+    assert "R.vision.multibox_transform_loc" in ir
+    assert "R.vision.all_class_non_max_suppression" in ir
+    assert 'output_format="tensorflow"' in ir
+    assert "R.where" in ir
+    assert "R.gather_elements" in ir
+    assert "R.gather_nd" in ir
+    assert ir.count("R.topk(") == expected_topk_count
+    assert f"keep_background={expected_keep_background}" in ir
+    expected_batch = build_kwargs["batch_size"]
+    expected_max_detections = build_kwargs["max_detections"]
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((expected_batch, 
expected_max_detections, 4), "float32"),
+                relax.TensorStructInfo((expected_batch, 
expected_max_detections), "float32"),
+                relax.TensorStructInfo((expected_batch, 
expected_max_detections), "float32"),
+                relax.TensorStructInfo((expected_batch,), "float32"),
+            ]
+        ),
+    )
+
+    legalized = relax.transform.LegalizeOps()(mod)
+    legalized_ir = legalized.script()
+    assert "R.vision.all_class_non_max_suppression(" not in legalized_ir
+    assert "R.call_tir(" in legalized_ir
+    tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, 
mod["main"].ret_struct_info)
+
+
[email protected]("build_kwargs", _DETECTION_POSTPROCESS_SHAPE_CASES)
+def test_detection_postprocess_shape_variations(build_kwargs):
+    mod = _build_detection_postprocess_mod(**build_kwargs)
+    batch_size = build_kwargs["batch_size"]
+    num_anchors = build_kwargs["num_anchors"]
+    input_num_classes = build_kwargs["input_num_classes"]
+    max_detections = build_kwargs["max_detections"]
+
+    tvm.ir.assert_structural_equal(
+        mod["main"].params[1].struct_info,
+        relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), 
"float32"),
+    )
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((batch_size, max_detections, 4), 
"float32"),
+                relax.TensorStructInfo((batch_size, max_detections), 
"float32"),
+                relax.TensorStructInfo((batch_size, max_detections), 
"float32"),
+                relax.TensorStructInfo((batch_size,), "float32"),
+            ]
+        ),
+    )
+
 def _make_resize_expected(
     input_shape, output_size, method, coordinate_transformation_mode, 
rounding_method
 ):

Reply via email to