mbrookhart commented on a change in pull request #6839:
URL: https://github.com/apache/incubator-tvm/pull/6839#discussion_r522503532



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -519,14 +557,90 @@ def non_max_suppression(
             coord_start,
             id_index,
             score_index,
+            return_indices,
         ),
         dtype=[data.dtype, "int32"],
-        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
     )
-    # TODO(yongwww): Update cuda nms to be consistent with cpu version
     if return_indices:
-        return box_indices
+        out_shape = box_indices.shape
+        valid_box_count_shape = [box_indices.shape[0], 1]
+        valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", 
"valid_box_count")
+        output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
+        return te.extern(
+            [out_shape, valid_box_count_shape],
+            [box_indices],
+            lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], 
outs[1]),
+            dtype="int32",
+            out_buffers=[output, valid_box_count],
+            name="rearrange_indices_out_gpu",
+            tag="rearrange_indices_out_gpu",
+        )
 
     return out
+
+
def rearrange_indices_out_ir(data, output, valid_box_count):
    """IR routine to rearrange NMS output, moving all valid
    box indices to the front of each row.

    One CUDA block is launched per batch element (``blockIdx.x``);
    the scan over anchors within a row is sequential, since the
    compaction position depends on all earlier entries.

    An entry is considered valid when it lies in ``[0, num_anchors]``;
    out-of-range entries (``> num_anchors`` or ``< -num_anchors``) are
    clamped to index 0, and remaining negative entries are dropped.
    Trailing slots are padded with -1.

    Parameters
    ----------
    data : Buffer
        NMS output indices. 2-D buffer with shape
        [batch_size, num_anchors].

    output : Buffer
        2-D output buffer with shape [batch_size, num_anchors].
        Valid indices are compacted to the front; the tail is
        padded with -1.

    valid_box_count : Buffer
        Output buffer with shape [batch_size, 1] holding the
        number of valid boxes per batch element.

    Returns
    -------
    stmt : Stmt
        The IR statement implementing the rearrangement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]

    ib = tvm.tir.ir_builder.create()

    data = ib.buffer_ptr(data)
    valid_box_count = ib.buffer_ptr(valid_box_count)
    output = ib.buffer_ptr(output)

    with ib.new_scope():
        i = te.thread_axis("blockIdx.x")
        ib.scope_attr(i, "thread_extent", batch_size)
        # Per-thread write cursor: next free slot in the compacted row.
        # NOTE: shape was previously written as (1), i.e. the int 1; use an
        # explicit one-element tuple for the allocation extent.
        valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
        valid_idx[0] = 0
        with ib.for_range(0, num_anchors, name="j") as j:
            with ib.if_scope(data[i, j] >= 0):
                with ib.if_scope(data[i, j] > num_anchors):
                    # Out-of-range positive index: clamp to 0 but still
                    # count it as a valid box.
                    output[i, valid_idx[0]] = 0
                    valid_idx[0] = valid_idx[0] + 1
                with ib.else_scope():
                    output[i, valid_idx[0]] = data[i, j]
                    valid_idx[0] = valid_idx[0] + 1
            with ib.else_scope():
                with ib.if_scope(data[i, j] < -num_anchors):
                    # Out-of-range negative index: clamp to 0 and count it.
                    # Other negative entries mark suppressed boxes and are
                    # skipped.
                    output[i, valid_idx[0]] = 0
                    valid_idx[0] = valid_idx[0] + 1
            with ib.if_scope(j >= valid_idx[0]):
                # Pad slots past the current cursor with -1; a slot padded
                # here may legitimately be overwritten by a later valid
                # entry once the cursor catches up.
                output[i, j] = -1
        valid_box_count[i, 0] = valid_idx[0]

    return ib.get()

Review comment:
       I fixed a number of bugs in #6906 that allow me to run SSD-RN50 on CPU, 
but with those fixes, I'm hitting the issue of CUDA topk not supporting dynamic 
shapes, and the output of NMS is intrinsically dynamic. I'm back to trying to 
solve that problem before I can give you perf numbers.
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to