masahi commented on a change in pull request #6314:
URL: https://github.com/apache/incubator-tvm/pull/6314#discussion_r473891771
##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -1428,6 +1428,48 @@ def test_forward_upsample3d():
verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear',
align_corners=True).eval(), inp)
+def test_forward_nms():
+ """dynamic Non-Maximum Suppression"""
+ torch.set_grad_enabled(False)
+ class NonMaxSupression1(Module):
+ def forward(self, *args):
+ return torchvision.ops.nms(args[0], args[1], 0.3)
+
+ class NonMaxSupression2(Module):
+ def forward(self, *args):
+ from torchvision.ops import nms
+ return torchvision.ops.nms(args[0], args[1], 0.5)
+
+ class NonMaxSupression3(Module):
+ def forward(self, *args):
+ from torchvision.ops import nms
+ return torchvision.ops.nms(args[0], args[1], 0.9)
+
+ # Generate random input data
+ def _gen_rand_inputs(num_boxes):
+ box_len = 4
+ boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5
+ boxes[:, 2] += boxes[:, 0]
+ boxes[:, 3] += boxes[:, 1]
+ scores = torch.rand(num_boxes, dtype=torch.float)
+ return boxes, scores
+
+ in_boxes, in_scores = _gen_rand_inputs(10)
+ scripted_model1 = torch.jit.trace(NonMaxSupression1(), [in_boxes,
in_scores])
+ verify_script_model(scripted_model1, [in_boxes.shape, in_scores.shape],
+ idata=[in_boxes, in_scores])
+
+ in_boxes, in_scores = _gen_rand_inputs(100)
+ scripted_model2 = torch.jit.trace(NonMaxSupression2(), [in_boxes,
in_scores])
+ verify_script_model(scripted_model2, [in_boxes.shape, in_scores.shape],
+ idata=[in_boxes, in_scores])
+
+ in_boxes, in_scores = _gen_rand_inputs(500)
+ scripted_model3 = torch.jit.trace(NonMaxSupression3(), [in_boxes,
in_scores])
+ verify_script_model(scripted_model3, [in_boxes.shape, in_scores.shape],
+ idata=[in_boxes, in_scores])
+
Review comment:
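Please clean this up: replace the three near-identical modules with a single `NonMaxSupression` module that takes the IoU threshold as a constructor argument, and loop over the test cases, like this: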
```
for num_boxes, thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
in_boxes, in_scores = _gen_rand_inputs(num_boxes)
traced_model = torch.jit.trace(NonMaxSupression(thres), [in_boxes,
in_scores])
verify_script_model(traced_model, [in_boxes.shape, in_scores.shape],
idata=[in_boxes, in_scores])
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org