masahi commented on a change in pull request #7123:
URL: https://github.com/apache/tvm/pull/7123#discussion_r546360307
##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -211,31 +215,119 @@ def get_valid_indices_ir(valid_boxes, valid_count,
valid_indices):
valid_indices = ib.buffer_ptr(valid_indices)
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+ # Copy boxes to valid_indices
with ib.new_scope():
nthread_tx = max_threads
- nthread_bx = batch_size // max_threads + 1
+ nthread_bx = ceil_div(num_anchors, max_threads)
+ nthread_by = batch_size
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ by = te.thread_axis("blockIdx.y")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ ib.scope_attr(by, "thread_extent", nthread_by)
+ tid = bx * nthread_tx + tx
+ with ib.if_scope(tid < num_anchors):
+ valid_indices[by, tid] = valid_boxes[by, tid]
+
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(num_anchors, max_threads)
+ nthread_by = batch_size
+
+ ## The following algorithm performs parallel exclusive scan to get
+ ## a tensor that can later be used to select valid indices
+ # Up Sweep of exclusive scan
+ lim = tvm.tir.generic.cast(
+ tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(num_anchors,
"float64"))), "int64"
+ )
+ with ib.for_range(0, lim, dtype="int64") as l2_width:
+ width = 2 << l2_width
+
+ with ib.new_scope():
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(
+ bx,
+ "thread_extent",
+ tvm.tir.generic.cast(ceil_div(num_anchors, max_threads *
width), "int32"),
+ )
+ tid = bx * nthread_tx + tx
+
+ by = te.thread_axis("blockIdx.y")
+ ib.scope_attr(by, "thread_extent", nthread_by)
+ start = ib.allocate("int64", (1,), name="start", scope="local")
+ middle = ib.allocate("int64", (1,), name="middle", scope="local")
+ end = ib.allocate("int64", (1,), name="end", scope="local")
+ start[0] = width * tid
+ with ib.if_scope(start[0] < num_anchors):
+ middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
+ end[0] = tvm.te.min(start[0] + width, num_anchors)
+ with ib.if_scope(middle[0] < num_anchors):
+ valid_indices[by * num_anchors + end[0] - 1] +=
valid_indices[
+ by * num_anchors + middle[0] - 1
+ ]
+
+ # Down Sweep of exclusive scan
+ with ib.new_scope():
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(bx, "thread_extent", batch_size)
+ with ib.if_scope(bx < batch_size):
+ valid_indices[(bx + 1) * num_anchors - 1] = 0
Review comment:
Since the reduction stored in `valid_indices[(bx + 1) * num_anchors - 1]` is needed later anyway, we should do
```
valid_count[bx] = valid_indices[(bx + 1) * num_anchors - 1]
```
right here before we overwrite it. We can then remove the last kernel at
L311 as a bonus.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]