trevor-m commented on a change in pull request #8174:
URL: https://github.com/apache/tvm/pull/8174#discussion_r644337411



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -793,6 +793,89 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def convert_combined_nms_with_all_class_nms(
+    batch_size,
+    max_output_boxes_per_batch,
+    num_class,
+    boxes,
+    scores,
+    max_output_boxes_per_class,
+    iou_threshold,
+    score_threshold,
+    max_total_size,
+    clip_boxes,
+):
+    """Converts TF combined_nms using Relay all_class_max_suppression op"""
+    (selected_indices, selected_scores, num_detections,) = 
_op.vision.all_class_non_max_suppression(
+        boxes,
+        scores,
+        max_output_boxes_per_class,
+        iou_threshold,
+        score_threshold,
+        output_format="tensorflow",
+    )
+    box_range = _op.arange(
+        _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), 
dtype="int64"
+    )
+    assert isinstance(batch_size, int), "dynamic batch size not supported yet."
+    tile_batch_reps = _op.const([batch_size, 1])
+    box_range_2d = _op.tile(box_range, tile_batch_reps)
+    valid_mask = _op.cast(
+        _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), 
"float32"
+    )
+
+    def select_topk(do_zero_pad):
+        def true_branch():
+            arange = _op.arange(
+                _op.const(0, dtype="int64"),
+                _op.const(max_output_boxes_per_batch, dtype="int64"),
+                dtype="int64",
+            )
+            pad = _op.full(
+                _op.const(0, dtype="int64"), (max_total_size - 
max_output_boxes_per_batch,)
+            )
+            topk_indices = _op.tile(_op.concatenate([arange, pad], 0), 
tile_batch_reps)
+            nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
+            nmsed_scores = nmsed_scores * valid_mask
+            return nmsed_scores, topk_indices
+
+        def false_branch():
+            if isinstance(max_output_boxes_per_class, int):
+                # Do topk on smaller input if possible
+                # TODO(masahi): use axes argument in strided slice when it 
becomes available
+                slice_mx = _op.const([-1, max_output_boxes_per_class * 
num_class], dtype="int64")
+                selected_scores_slice = _op.strided_slice(
+                    selected_scores, begin=_op.const([0, 0], dtype="int64"), 
end=slice_mx

Review comment:
       Thanks for fixing!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to