trevor-m commented on a change in pull request #8174:
URL: https://github.com/apache/tvm/pull/8174#discussion_r644321057
##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -793,6 +793,89 @@ def _impl(inputs, attr, params, mod):
return _impl
+def convert_combined_nms_with_all_class_nms(
+ batch_size,
+ max_output_boxes_per_batch,
+ num_class,
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ max_total_size,
+ clip_boxes,
+):
+ """Converts TF combined_nms using Relay all_class_max_suppression op"""
+ (selected_indices, selected_scores, num_detections,) =
_op.vision.all_class_non_max_suppression(
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ output_format="tensorflow",
+ )
+ box_range = _op.arange(
+ _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"),
dtype="int64"
+ )
+ assert isinstance(batch_size, int), "dynamic batch size not supported yet."
+ tile_batch_reps = _op.const([batch_size, 1])
+ box_range_2d = _op.tile(box_range, tile_batch_reps)
+ valid_mask = _op.cast(
+ _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)),
"float32"
+ )
+
+ def select_topk(do_zero_pad):
+ def true_branch():
+ arange = _op.arange(
+ _op.const(0, dtype="int64"),
+ _op.const(max_output_boxes_per_batch, dtype="int64"),
+ dtype="int64",
+ )
+ pad = _op.full(
+ _op.const(0, dtype="int64"), (max_total_size -
max_output_boxes_per_batch,)
+ )
+ topk_indices = _op.tile(_op.concatenate([arange, pad], 0),
tile_batch_reps)
+ nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
+ nmsed_scores = nmsed_scores * valid_mask
+ return nmsed_scores, topk_indices
+
+ def false_branch():
+ if isinstance(max_output_boxes_per_class, int):
+ # Do topk on smaller input if possible
+ # TODO(masahi): use axes argument in strided slice when it
becomes available
+ slice_mx = _op.const([-1, max_output_boxes_per_class *
num_class], dtype="int64")
+ selected_scores_slice = _op.strided_slice(
+ selected_scores, begin=_op.const([0, 0], dtype="int64"),
end=slice_mx
Review comment:
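This strided_slice needs `slice_mode="size"`. Otherwise (in the default "end"
mode), the -1 in `end` is treated as an exclusive end index rather than "take the
whole axis", so the result has a batch size of 0, as the IR below shows.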
```
%988 = vision.all_class_non_max_suppression(%986, %987, 100 /* ty=int32 */,
0.6f /* ty=float32 */, 0.3f /* ty=float32 */,
meta[relay.attrs.AllClassNonMaximumSuppressionAttrs][0]) /* ty=(Tensor[(1,
172530, 2), int64], Tensor[(1, 172530), float32], Tensor[(1), int64]) */;
%989 = %988.1;
%990 = strided_slice(%989, begin=[0, 0], end=[-1, 9000], strides=[1]) /*
ty=Tensor[(0, 9000), float32] */;
```
##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -1012,39 +1079,69 @@ def all_class_non_max_suppression(
score_threshold : float or tvm.te.Tensor, optional
Score threshold to filter out low score boxes early
+ output_format : str, optional
+ "onnx" or "tensorflow", see below
+
Returns
-------
- out : [tvm.te.Tensor, tvm.te.Tensor]
- The output is two tensors, the first is `indices` of size
+ out : list of tvm.te.Tensor
+ If `output_format` is "onnx", the output is two tensors. The first is
`indices` of size
`(batch_size * num_class* num_boxes , 3)` and the second is a scalar
tensor
`num_total_detection` of shape `(1,)` representing the total number of
selected
- boxes. Rows of `indices` are ordered such that selected boxes from
batch 0, class 0 come
+ boxes. The three values in `indices` encode batch, class, and box
indices.
+ Rows of `indices` are ordered such that selected boxes from batch 0,
class 0 come
first, in descending of scores, followed by boxes from batch 0, class
1 etc. Out of
`batch_size * num_class* num_boxes` rows of indices, only the first
`num_total_detection`
rows are valid.
+
+ If `output_format` is "tensorflow", the output is three tensors, the
first
+ is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the
second is `scores` of
+ size `(batch_size, num_class * num_boxes)`, and the third is
`num_total_detection` of size
+ `(batch_size,)` representing the total number of selected boxes per
batch. The two values
+ in `indices` encode class and box indices. Of num_class * num_boxes
boxes in `indices` at
+ batch b, only the first `num_total_detection[b]` entries are valid.
The second axis of
+ `indices` and `scores` are sorted within each class by box scores, but
not across classes.
+ So the box indices and scores for the class 0 come first in a sorted
order, followed by
+ the class 1 etc.
"""
batch, num_class, num_boxes = scores.shape
scores = reshape(scores, (batch * num_class, num_boxes))
sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both")
valid_count = _get_valid_box_count(sorted_scores, score_threshold)
- selected_indices, num_detections = run_all_class_nms(
+ selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
+ return_scores=(output_format == "tensorflow"),
)
+ if output_format == "onnx":
+ row_offsets, num_total_detections = exclusive_scan(
+ num_detections, return_reduction=True, output_dtype="int64"
+ )
+ selected_indices = collect_selected_indices(
+ num_class, selected_indices, num_detections, row_offsets,
_collect_selected_indices_ir
+ )
+ return [selected_indices, num_total_detections]
+
+ num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets, num_total_detections = exclusive_scan(
- num_detections, return_reduction=True, output_dtype="int64"
+ num_detections_per_batch, return_reduction=True, output_dtype="int64",
axis=1
Review comment:
Because of the int64 output type, I had to change this line
https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/scan.py#L234
to `reduction[tid] = cast(0, "int64")`.
Otherwise, I got the error: `data type does not match content type int32 vs int64`.
##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -793,6 +793,89 @@ def _impl(inputs, attr, params, mod):
return _impl
+def convert_combined_nms_with_all_class_nms(
+ batch_size,
+ max_output_boxes_per_batch,
+ num_class,
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ max_total_size,
+ clip_boxes,
+):
+ """Converts TF combined_nms using Relay all_class_max_suppression op"""
+ (selected_indices, selected_scores, num_detections,) =
_op.vision.all_class_non_max_suppression(
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ output_format="tensorflow",
+ )
+ box_range = _op.arange(
+ _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"),
dtype="int64"
+ )
+ assert isinstance(batch_size, int), "dynamic batch size not supported yet."
+ tile_batch_reps = _op.const([batch_size, 1])
+ box_range_2d = _op.tile(box_range, tile_batch_reps)
+ valid_mask = _op.cast(
+ _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)),
"float32"
+ )
+
+ def select_topk(do_zero_pad):
+ def true_branch():
+ arange = _op.arange(
+ _op.const(0, dtype="int64"),
+ _op.const(max_output_boxes_per_batch, dtype="int64"),
+ dtype="int64",
+ )
+ pad = _op.full(
+ _op.const(0, dtype="int64"), (max_total_size -
max_output_boxes_per_batch,)
+ )
+ topk_indices = _op.tile(_op.concatenate([arange, pad], 0),
tile_batch_reps)
+ nmsed_scores = _op.gather(selected_scores, 1, topk_indices)
+ nmsed_scores = nmsed_scores * valid_mask
+ return nmsed_scores, topk_indices
+
+ def false_branch():
+ if isinstance(max_output_boxes_per_class, int):
+ # Do topk on smaller input if possible
+ # TODO(masahi): use axes argument in strided slice when it
becomes available
+ slice_mx = _op.const([-1, max_output_boxes_per_class *
num_class], dtype="int64")
+ selected_scores_slice = _op.strided_slice(
+ selected_scores, begin=_op.const([0, 0], dtype="int64"),
end=slice_mx
Review comment:
Thanks for fixing!
##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -1012,39 +1079,69 @@ def all_class_non_max_suppression(
score_threshold : float or tvm.te.Tensor, optional
Score threshold to filter out low score boxes early
+ output_format : str, optional
+ "onnx" or "tensorflow", see below
+
Returns
-------
- out : [tvm.te.Tensor, tvm.te.Tensor]
- The output is two tensors, the first is `indices` of size
+ out : list of tvm.te.Tensor
+ If `output_format` is "onnx", the output is two tensors. The first is
`indices` of size
`(batch_size * num_class* num_boxes , 3)` and the second is a scalar
tensor
`num_total_detection` of shape `(1,)` representing the total number of
selected
- boxes. Rows of `indices` are ordered such that selected boxes from
batch 0, class 0 come
+ boxes. The three values in `indices` encode batch, class, and box
indices.
+ Rows of `indices` are ordered such that selected boxes from batch 0,
class 0 come
first, in descending of scores, followed by boxes from batch 0, class
1 etc. Out of
`batch_size * num_class* num_boxes` rows of indices, only the first
`num_total_detection`
rows are valid.
+
+ If `output_format` is "tensorflow", the output is three tensors, the
first
+ is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the
second is `scores` of
+ size `(batch_size, num_class * num_boxes)`, and the third is
`num_total_detection` of size
+ `(batch_size,)` representing the total number of selected boxes per
batch. The two values
+ in `indices` encode class and box indices. Of num_class * num_boxes
boxes in `indices` at
+ batch b, only the first `num_total_detection[b]` entries are valid.
The second axis of
+ `indices` and `scores` are sorted within each class by box scores, but
not across classes.
+ So the box indices and scores for the class 0 come first in a sorted
order, followed by
+ the class 1 etc.
"""
batch, num_class, num_boxes = scores.shape
scores = reshape(scores, (batch * num_class, num_boxes))
sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both")
valid_count = _get_valid_box_count(sorted_scores, score_threshold)
- selected_indices, num_detections = run_all_class_nms(
+ selected_indices, selected_scores, num_detections = run_all_class_nms(
boxes,
sorted_scores,
sorted_indices,
valid_count,
max_output_boxes_per_class,
iou_threshold,
_nms_loop,
+ return_scores=(output_format == "tensorflow"),
)
+ if output_format == "onnx":
+ row_offsets, num_total_detections = exclusive_scan(
+ num_detections, return_reduction=True, output_dtype="int64"
+ )
+ selected_indices = collect_selected_indices(
+ num_class, selected_indices, num_detections, row_offsets,
_collect_selected_indices_ir
+ )
+ return [selected_indices, num_total_detections]
+
+ num_detections_per_batch = reshape(num_detections, (batch, num_class))
row_offsets, num_total_detections = exclusive_scan(
- num_detections, return_reduction=True, output_dtype="int64"
+ num_detections_per_batch, return_reduction=True, output_dtype="int64",
axis=1
Review comment:
Nice, thanks!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org