masahi commented on a change in pull request #7123:
URL: https://github.com/apache/tvm/pull/7123#discussion_r545429638



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -211,28 +211,119 @@ def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
     valid_count = ib.buffer_ptr(valid_count)
     valid_indices = ib.buffer_ptr(valid_indices)
 
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    def ceil_div(a, b):
+        return tvm.tir.indexdiv(a + b - 1, b)
+
+    # Copy boxes to valid_indices
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(batch_size * num_anchors, max_threads)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            valid_indices[tid] = valid_boxes[tid]
+
+    nthread_tx = max_threads
+    nthread_bx = ceil_div(num_anchors, max_threads)
+    nthread_by = batch_size
+
+    ## The following algorithm performs parallel prefix sum to get
+    ## a tensor that can later be used to select valid indices
+    # Up Sweep of prefix sum
+    lim = tvm.tir.generic.cast(
+        tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors, "float64"))), "int64"
+    )
+    with ib.for_range(0, lim, dtype="int64") as l2_width:
+        width = 2 << l2_width
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(start[0] < num_anchors):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                end[0] = tvm.te.min(start[0] + width, num_anchors)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + end[0] - 1] += valid_indices[
+                        by * num_anchors + middle[0] - 1
+                    ]
+
+    # Down Sweep of prefix sum
+    with ib.for_range(0, lim - 1, dtype="int64") as l2_width:
+        width = 2 << (lim - l2_width - 2)
+
+        with ib.new_scope():
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(
+                bx,
+                "thread_extent",
+                tvm.tir.generic.cast(ceil_div(num_anchors, max_threads * width), "int32"),
+            )
+            tid = bx * nthread_tx + tx
+
+            by = te.thread_axis("blockIdx.y")
+            ib.scope_attr(by, "thread_extent", nthread_by)
+            start = ib.allocate("int64", (1,), name="start", scope="local")
+            middle = ib.allocate("int64", (1,), name="middle", scope="local")
+            end = ib.allocate("int64", (1,), name="end", scope="local")
+            start[0] = width * tid
+            with ib.if_scope(tvm.tir.all(start[0] > 0, start[0] < num_anchors)):
+                middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+                with ib.if_scope(middle[0] < num_anchors):
+                    valid_indices[by * num_anchors + middle[0] - 1] += valid_indices[
+                        by * num_anchors + start[0] - 1
+                    ]
+
+    ## Write Sum to valid_count
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     with ib.new_scope():
         nthread_tx = max_threads
-        nthread_bx = batch_size // max_threads + 1
+        nthread_bx = ceil_div(batch_size, max_threads)
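
For reference, the two `ib.for_range` loops above implement a work-efficient (Blelloch-style) inclusive scan over each batch row of `valid_indices`. Below is a minimal sequential Python sketch of the same schedule, handy for checking the indexing on the host; `scan_row` and `boxes` are illustrative names, not part of the PR:

```python
import math

def scan_row(boxes):
    """Inclusive prefix sum over one batch row, mirroring the kernel's indexing."""
    a = list(boxes)
    n = len(a)
    lim = math.ceil(math.log2(n))
    # Up sweep: fold each block's left-half total into the block's last slot.
    for l2_width in range(lim):
        width = 2 << l2_width
        for start in range(0, n, width):  # one GPU thread per block of size `width`
            middle = start + width // 2
            end = min(start + width, n)
            if middle < n:
                a[end - 1] += a[middle - 1]
    # Down sweep: propagate block totals into the interior positions.
    for l2_width in range(lim - 1):
        width = 2 << (lim - l2_width - 2)
        for start in range(width, n, width):  # start > 0, as in the kernel
            middle = start + width // 2
            if middle < n:
                a[middle - 1] += a[start - 1]
    return a

assert scan_row([1, 0, 1, 1, 0, 1, 1, 1]) == [1, 1, 2, 3, 3, 4, 5, 6]
```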

Review comment:
       What is the typical size of `batch_size` here? Do we need to launch this many threads (1024 minimum)?
   
   How about fusing this kernel (populating `valid_counts`) into the last iteration of the down-sweep phase above? You have plenty of threads there, and only some of them need to write to `valid_counts`.
   



