masahi commented on a change in pull request #7136:
URL: https://github.com/apache/tvm/pull/7136#discussion_r546502087



##########
File path: python/tvm/topi/cuda/nms.py
##########
@@ -523,132 +473,112 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
     max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
     with ib.new_scope():
+        nthread_tx = max_threads
+        # num_anchors can be zero
+        nthread_bx = tvm.tir.max(1, ceil_div(num_anchors, max_threads))
         nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
         i = by
         base_idx = i * num_anchors * box_data_length
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Reorder output
             nkeep = if_then_else(
                 tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, 
valid_count[i]
             )
-            with ib.for_range(0, nkeep) as j:
+            j = bx * max_threads + tx
+            with ib.if_scope(j < num_anchors):
+                box_indices[i * num_anchors + j] = -1
+            with ib.if_scope(j < nkeep):
+                # Fill in out with sorted boxes
                 with ib.for_range(0, box_data_length) as k:
                     out[(base_idx + j * box_data_length + k)] = data[
                         (base_idx + sorted_index[i * num_anchors + j] * 
box_data_length + k)
                     ]
-                box_indices[i * num_anchors + j] = sorted_index[i * 
num_anchors + j]
-            with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
-                with ib.for_range(0, valid_count[i] - nkeep) as j:
+            with ib.else_scope():
+                # Indices > nkeep are discarded
+                with ib.if_scope(j < num_anchors):
                     with ib.for_range(0, box_data_length) as k:
-                        out[(base_idx + (j + nkeep) * box_data_length + k)] = 
-1.0
-                    box_indices[i * num_anchors + (j + nkeep)] = -1
+                        out[(base_idx + j * box_data_length + k)] = -1.0
+        with ib.else_scope():
+            with ib.if_scope(j < valid_count[i]):
+                with ib.for_range(0, box_data_length) as k:
+                    offset = base_idx + j * box_data_length + k
+                    out[offset] = data[offset]
+                box_indices[i * num_anchors + j] = j
+
     with ib.new_scope():
         nthread_by = batch_size
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(by, "thread_extent", nthread_by)
         i = by
         base_idx = i * num_anchors * box_data_length
+        num_valid_boxes_local = ib.allocate(
+            "int32", (1,), name="num_valid_boxes_local", scope="local"
+        )
+        num_valid_boxes_local[0] = 0
+
+        def nms_inner_loop(ib, j):
+            offset_j = j * box_data_length
+
+            with ib.for_range(0, j) as k:
+                offset_k = k * box_data_length
+
+                with ib.if_scope(
+                    tvm.tir.all(
+                        out[base_idx + offset_j + score_index] > -1.0,  # if 
already surpressed

Review comment:
       I tried this; is that what you are suggesting? I thought these two
snippets were equivalent, but the one below is actually slower. I don't know
what's happening, but I'm OK with the current implementation. The triangle loop
needs rework later anyway.
   
   ```
   # check if box j is already suppressed
   with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
       with ib.if_scope(
           tvm.tir.all(
               out[base_idx + offset_k + score_index] > 0,
               tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
               tvm.tir.any(
                   force_suppress > 0,
                   id_index < 0,
                   out[base_idx + offset_k + id_index]
                   == out[base_idx + offset_j + id_index],
               ),
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to