masahi commented on a change in pull request #7123:
URL: https://github.com/apache/tvm/pull/7123#discussion_r546360307



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -211,31 +215,119 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
     valid_indices = ib.buffer_ptr(valid_indices)

     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    # Copy boxes to valid_indices
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = batch_size // max_threads + 1
+        nthread_bx = ceil_div(num_anchors, max_threads)
+        nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        tid = bx * nthread_tx + tx
+        with ib.if_scope(tid < num_anchors):
+            valid_indices[by, tid] = valid_boxes[by, tid]
+
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
+
+    ## The following algorithm performs parallel exclusive scan to get
+    ## a tensor that can later be used to select valid indices
+    # Up Sweep of exclusive scan
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(start[0] < num_anchors):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of exclusive scan
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", batch_size)
+        with ib.if_scope(bx < batch_size):
+            valid_indices[(bx + 1) * num_anchors - 1] = 0
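
For context, here is a minimal sequential Python sketch of the Blelloch-style exclusive scan that this hunk parallelizes; the hunk shows the up sweep and the zeroing that begins the down sweep. The function name, the `data` argument, and the full down sweep below are illustrative assumptions, not code from the PR:

```
import math

def exclusive_scan(data):
    """Blelloch exclusive scan, emulated sequentially."""
    n = len(data)
    out = list(data)
    lim = math.ceil(math.log2(n)) if n > 1 else 0

    # Up sweep: build partial sums in a binary-tree pattern. The stride
    # `width` doubles each pass; the kernel runs each pass as one launch,
    # with bounds checks standing in for padding to a power of two.
    for l2_width in range(lim):
        width = 2 << l2_width
        for start in range(0, n, width):
            middle = start + width // 2
            end = min(start + width, n)
            if middle < n:
                out[end - 1] += out[middle - 1]

    # Down sweep: zero the root, then push prefix sums back down the tree.
    total = out[n - 1]  # the full reduction; the review below suggests
                        # saving this before it is overwritten
    out[n - 1] = 0
    for l2_width in reversed(range(lim)):
        width = 2 << l2_width
        for start in range(0, n, width):
            middle = start + width // 2
            end = min(start + width, n)
            if middle < n:
                out[middle - 1], out[end - 1] = (
                    out[end - 1],
                    out[end - 1] + out[middle - 1],
                )
    return out, total

# exclusive_scan([1, 0, 1, 1]) returns ([0, 1, 1, 2], 3)
```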

Review comment:
       Since the reduction stored in `valid_indices[(bx + 1) * num_anchors - 1]` is needed later anyway, we should do
   
   ```
   valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
   ```
   
   right here, before we overwrite it. We can then remove the last kernel at L311 as a bonus.
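   
   For illustration, the down-sweep prologue with that change applied might look like this (a sketch only, assuming `valid_count = ib.buffer_ptr(valid_count)` earlier in the function, as with the other buffers):
   
   ```
   with ib.new_scope():
       bx = te.thread_axis("blockIdx.x")
       ib.scope_attr(bx, "thread_extent", batch_size)
       with ib.if_scope(bx < batch_size):
           # Save the scan's full reduction (the number of valid boxes for
           # this batch element) before zeroing it for the down sweep.
           valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
           valid_indices[(bx + 1) * num_anchors - 1] = 0
   ```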



