Laurawly commented on a change in pull request #6839:
URL: https://github.com/apache/tvm/pull/6839#discussion_r539693783
##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -97,47 +97,44 @@ def get_valid_counts_ir(
valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out)
out_indices = ib.buffer_ptr(out_indices)
- atomic_add_return = ib.allocate(
- valid_count.dtype, (1,), name="atomic_add_return", scope="local"
- )
- one_count = tvm.tir.const(1, dtype=valid_count.dtype)
one = tvm.tir.const(1, dtype=out.dtype)
- score_threshold = tvm.ir.make_node("FloatImm", dtype="float32",
value=score_threshold)
+ if isinstance(score_threshold, float):
+ score_threshold = tvm.ir.make_node("FloatImm", dtype="float32",
value=score_threshold)
id_index = tvm.ir.make_node("IntImm", dtype="int32", value=id_index)
score_index = tvm.ir.make_node("IntImm", dtype="int32", value=score_index)
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
- nthread_tx = max_threads
- nthread_bx = batch_size * num_anchors // max_threads + 1
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- tid = bx * max_threads + tx
- idxd = tvm.tir.indexdiv
-
- # initialize valid_count
- with ib.if_scope(tid < batch_size):
- valid_count[tid] = 0
- with ib.if_scope(tid < batch_size * num_anchors):
- i = idxd(tid, num_anchors)
- with ib.if_scope(
- tvm.tir.all(
- data[tid * elem_length + score_index] > score_threshold,
- tvm.tir.any(id_index < 0, data[tid * elem_length + id_index]
>= 0),
- )
- ):
- atomic_add_return[0] = atomic_add(
- tvm.tir.call_intrin("handle", "tir.address_of",
valid_count[i]), one_count
- )
- with ib.for_range(0, elem_length) as k:
- out[tid * elem_length + k] = data[tid * elem_length + k]
- out_indices[tid + k] = tid + k
- with ib.else_scope():
- with ib.for_range(0, elem_length) as k:
- out[tid * elem_length + k] = -one
- out_indices[tid + k] = -one_count
-
+ with ib.new_scope():
+ nthread_tx = max_threads
+ nthread_bx = batch_size // max_threads + 1
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
+ tid = bx * max_threads + tx
+ with ib.if_scope(tid < batch_size):
+ valid_count[tid] = 0
+ i = tid
+ with ib.for_range(0, num_anchors) as j:
+ score = data[(i * num_anchors + j) * elem_length + score_index]
+ with ib.if_scope(
+ tvm.tir.all(
+ score > score_threshold,
+ tvm.tir.any(
+ id_index < 0, data[(i * num_anchors + j) *
elem_length + id_index] >= 0
+ ),
+ )
+ ):
+ with ib.for_range(0, elem_length) as k:
+ out[(i * num_anchors + valid_count[i]) * elem_length +
k] = data[
+ (i * num_anchors + j) * elem_length + k
+ ]
+ out_indices[i * num_anchors + valid_count[i]] = j
+ valid_count[i] += 1
Review comment:
Yes, that method was actually my first implementation of NMS, before PR
#5339. In that PR, the only things I improved in `get_valid_counts_ir` were
computing the correct valid_count and marking invalid outputs with -1. After
`get_valid_counts_ir`, I do a sort (in descending order) to get the sorted
outputs as well as the sorted output indices. This second implementation is
currently in main and is much faster than the previous one.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]