kevinthesun commented on a change in pull request #7137:
URL: https://github.com/apache/tvm/pull/7137#discussion_r546923370
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1857,16 +1857,18 @@ def nms(self, inputs, input_types):
scores = inputs[1]
iou_threshold = inputs[2]
+ num_boxes = _op.shape_of(scores)
+
+ # TVM NMS assumes score > 0
+ scores = scores - _op.min(scores) + _op.const(1.0)
# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims", extras={"axis": -1,
"num_newaxis": 1})([scores], {})
-
- # Prepare input data for get_valid_counts
data = _op.concatenate([scores, boxes], -1)
data = _op.expand_dims(data, 0, 1)
- # Leverage get_valid_counts to sort the data and clear invalid boxes
- ct, data, indices = get_relay_op("get_valid_counts")(
- data, score_threshold=-1.0, id_index=-1, score_index=0
- )
+ # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count
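Why the `scores - _op.min(scores) + _op.const(1.0)` line is safe can be checked outside TVM. Below is a minimal pure-Python sketch. It uses no TVM or torch, and `shift_scores_positive`, `pack_nms_rows`, and `ranking` are hypothetical helpers; the (x1, y1, x2, y2) box layout is also an assumption. The sketch mirrors the shift and the `[score, box]` concatenation. Every shifted score comes out at least 1.0, which satisfies the "score > 0" assumption noted in the hunk, and the descending score order that NMS relies on is unchanged.

```python
from typing import List, Sequence, Tuple

# Assumed (x1, y1, x2, y2) box layout; the hunk itself does not say.
Box = Tuple[float, float, float, float]


def shift_scores_positive(scores: Sequence[float]) -> List[float]:
    """Mirror `scores - min(scores) + 1.0`: every result is >= 1.0.

    Shifting all scores by the same amount never moves a lower score above
    a higher one, so the descending order NMS relies on is preserved.
    """
    lowest = min(scores)
    return [s - lowest + 1.0 for s in scores]


def pack_nms_rows(scores: Sequence[float], boxes: Sequence[Box]) -> List[List[float]]:
    """One [score, x1, y1, x2, y2] row per anchor, score at index 0.

    Mirrors the concatenate step, minus the leading batch axis of size 1.
    """
    return [[s, *box] for s, box in zip(shift_scores_positive(scores), boxes)]


def ranking(values: Sequence[float]) -> List[int]:
    """Indices sorted by descending value, ties broken by index."""
    return sorted(range(len(values)), key=lambda i: (-values[i], i))


if __name__ == "__main__":
    scores = [-0.5, 0.25, -2.0, 0.75]
    boxes = [(0, 0, 1, 1), (0, 0, 2, 2), (1, 1, 3, 3), (2, 2, 4, 4)]
    rows = pack_nms_rows(scores, boxes)
    print(rows[0])  # [2.5, 0, 0, 1, 1]
    print(ranking(scores), ranking([r[0] for r in rows]))  # [3, 1, 0, 2] [3, 1, 0, 2]
```

In the Relay version the same steps run as graph ops (`_op.min`, `_op.concatenate`, `_op.expand_dims`). The sketch only checks the arithmetic.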
Review comment:
torchvision nms doesn't filter out invalid boxes before running NMS, so
NMS has to process every box, which can be very slow. Filtering out
negative-score boxes should have no effect on the results. We could discuss
further what we actually gain from this change.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.