This is an automated email from the ASF dual-hosted git repository.

laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 82942fb  [TOPI] Simplify GPU NMS IR and optimize a bit (#7136)
82942fb is described below

commit 82942fb33fd6e3572897c815af16905c4f75c2a4
Author: masahi <[email protected]>
AuthorDate: Tue Dec 22 02:17:26 2020 +0900

    [TOPI] Simplify GPU NMS IR and optimize a bit (#7136)
    
    * remove get_valid_counts from pytorch nms
    
    * fix pytorch nms for negative score
    
    * merge reset by -1
    
    * move max_out_size handling to triangle loop
    
    * update torch nms test
    
    * fuse the last two kernels
    
    * parallelize the first kernel
    
    * merge first and last kernel
    
    * remove unnecessary cases
    
    * fix typo
    
    * revert pytorch frontend change
    
    * fuse rearrange step with triangle loop
    
    * fix max_output_size handling
    
    * check if already suppressed
    
    * fix topi vision test by wrapping tir const around int argument
    
    * fix for num anchors = 0 case
    
    * fix missing zero init of num valid boxes when the input is empty
    
    * add some comments and missing doc
    
    * typo fix
    
    * add a guard against zero dim grid / thread block inside ir_builder
    
    * typo fix
    
    * trigger CI
---
 python/tvm/tir/ir_builder.py |   4 +
 python/tvm/topi/cuda/nms.py  | 279 +++++++++++++++----------------------------
 2 files changed, 102 insertions(+), 181 deletions(-)

diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 75c5c29..6dcc858 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -21,6 +21,7 @@ from tvm.ir import container as _container, PointerType, 
PrimType
 
 from . import stmt as _stmt
 from . import expr as _expr
+from . import op
 
 
 class WithScope(object):
@@ -200,6 +201,9 @@ class IRBuilder(object):
             node = _expr.StringImm(node)
         if isinstance(value, string_types):
             value = _expr.StringImm(value)
+        # thread_extent could be zero for dynamic workloads
+        if attr_key == "thread_extent":
+            value = op.max(1, value)
         self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
 
     def for_range(self, begin, end, name="i", dtype="int32", 
for_type="serial"):
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2733970..cea287e 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -51,68 +51,8 @@ def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
-def rearrange_indices_out_ir(data, output, valid_box_count):
-    """Hybrid routine to rearrange nms output to
-    move all valid entries to top.
-
-    Parameters
-    ----------
-    data : tvm.te.Tensor or numpy NDArray
-        NMS output. 3-D tensor with shape
-        [batch_size, num_anchors, 6] or
-        [batch_size, num_anchors, 5], or 2-D
-        tensor with shape [batch_size, num_anchors].
-
-    one: tvm.tir.const
-        Constant one with the same dtype as data.
-
-    batch_size: tvm.tir.IntImm or tvm.tir.Var
-        Batch size. We need to pass it in since hybrid script doesn't support
-        binding variable to symbolic dim.
-
-    num_anchors: tvm.tir.IntImm or tvm.tir.Var
-        Number of anchors.
-
-    Returns
-    -------
-    output : tvm.te.Tensor or numpy NDArray
-        2-D tensor with shape [batch_size, num_anchors].
-
-    valid_box_count : tvm.te.Tensor or numpy NDArray
-        Tensor with shape [batch_size, 1], indicates
-        the valid number of boxes.
-    """
-    batch_size = data.shape[0]
-    num_anchors = data.shape[1]
-
-    ib = tvm.tir.ir_builder.create()
-
-    data = ib.buffer_ptr(data)
-    valid_box_count = ib.buffer_ptr(valid_box_count)
-    output = ib.buffer_ptr(output)
-
-    with ib.new_scope():
-        i = te.thread_axis("blockIdx.x")
-        ib.scope_attr(i, "thread_extent", batch_size)
-        valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
-        valid_idx[0] = 0
-        with ib.for_range(0, num_anchors, name="j") as j:
-            with ib.if_scope(data[i, j] >= 0):
-                with ib.if_scope(data[i, j] > num_anchors):
-                    output[i, valid_idx[0]] = 0
-                    valid_idx[0] = valid_idx[0] + 1
-                with ib.else_scope():
-                    output[i, valid_idx[0]] = data[i, j]
-                    valid_idx[0] = valid_idx[0] + 1
-            with ib.else_scope():
-                with ib.if_scope(data[i, j] < -num_anchors):
-                    output[i, valid_idx[0]] = 0
-                    valid_idx[0] = valid_idx[0] + 1
-            with ib.if_scope(j >= valid_idx[0]):
-                output[i, j] = -1
-        valid_box_count[i, 0] = valid_idx[0]
-
-    return ib.get()
+def ceil_div(a, b):
+    return tvm.tir.indexdiv(a + b - 1, b)
 
 
 def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, 
score_index):
@@ -400,6 +340,7 @@ def nms_ir(
     indices,
     out,
     box_indices,
+    num_valid_boxes,
     max_output_size,
     iou_threshold,
     force_suppress,
@@ -430,7 +371,15 @@ def nms_ir(
         is not used before non_max_suppression.
 
     out : Buffer
-        Output buffer.
+        Output buffer, to be filled with sorted boxes.
+
+    box_indices : Buffer
+        A indices tensor mapping sorted indices to original indices
+        This is the first output of NMS when return_indices=True.
+
+    num_valid_boxes : Buffer
+        Record the number of boxes that have survived IOU tests.
+        This is the second output of NMS when return_indices=True.
 
     max_output_size : int
         Max number of output valid boxes for each instance.
@@ -509,6 +458,7 @@ def nms_ir(
     sorted_index = ib.buffer_ptr(sorted_index)
     valid_count = ib.buffer_ptr(valid_count)
     indices = ib.buffer_ptr(indices)
+    num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
     out = ib.buffer_ptr(out)
     box_indices = ib.buffer_ptr(box_indices)
 
@@ -523,9 +473,15 @@ def nms_ir(
     max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
 
     with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = ceil_div(num_anchors, max_threads)
         nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
         i = by
         base_idx = i * num_anchors * box_data_length
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
@@ -533,122 +489,95 @@ def nms_ir(
             nkeep = if_then_else(
                 tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, 
valid_count[i]
             )
-            with ib.for_range(0, nkeep) as j:
+            j = bx * max_threads + tx
+            with ib.if_scope(j < num_anchors):
+                box_indices[i * num_anchors + j] = -1
+            with ib.if_scope(j < nkeep):
+                # Fill in out with sorted boxes
                 with ib.for_range(0, box_data_length) as k:
                     out[(base_idx + j * box_data_length + k)] = data[
                         (base_idx + sorted_index[i * num_anchors + j] * 
box_data_length + k)
                     ]
-                box_indices[i * num_anchors + j] = sorted_index[i * 
num_anchors + j]
-            with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
-                with ib.for_range(0, valid_count[i] - nkeep) as j:
+            with ib.else_scope():
+                # Indices > nkeep are discarded
+                with ib.if_scope(j < num_anchors):
                     with ib.for_range(0, box_data_length) as k:
-                        out[(base_idx + (j + nkeep) * box_data_length + k)] = 
-1.0
-                    box_indices[i * num_anchors + (j + nkeep)] = -1
+                        out[(base_idx + j * box_data_length + k)] = -1.0
+        with ib.else_scope():
+            with ib.if_scope(j < valid_count[i]):
+                with ib.for_range(0, box_data_length) as k:
+                    offset = base_idx + j * box_data_length + k
+                    out[offset] = data[offset]
+                box_indices[i * num_anchors + j] = j
+
     with ib.new_scope():
         nthread_by = batch_size
         by = te.thread_axis("blockIdx.y")
         ib.scope_attr(by, "thread_extent", nthread_by)
         i = by
         base_idx = i * num_anchors * box_data_length
+        num_valid_boxes_local = ib.allocate(
+            "int32", (1,), name="num_valid_boxes_local", scope="local"
+        )
+        num_valid_boxes_local[0] = 0
+
+        def nms_inner_loop(ib, j):
+            offset_j = j * box_data_length
+
+            with ib.for_range(0, j) as k:
+                offset_k = k * box_data_length
+
+                with ib.if_scope(
+                    tvm.tir.all(
+                        out[base_idx + offset_j + score_index] > -1.0,  # if 
already surpressed
+                        out[base_idx + offset_k + score_index] > 0,
+                        tvm.tir.any(id_index < 0, out[base_idx + offset_k + 
id_index] >= 0),
+                        tvm.tir.any(
+                            force_suppress > 0,
+                            id_index < 0,
+                            out[base_idx + offset_k + id_index]
+                            == out[base_idx + offset_j + id_index],
+                        ),
+                    )
+                ):
+                    iou = calculate_overlap(
+                        out,
+                        base_idx + offset_j + coord_start,
+                        base_idx + offset_k + coord_start,
+                    )
+                    with ib.if_scope(iou >= iou_threshold):
+                        out[base_idx + offset_j + score_index] = -1.0
+                        with ib.if_scope(id_index >= 0):
+                            out[base_idx + offset_j + id_index] = -1.0
+
+            # Has the box j survived IOU tests?
+            with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
+                # When return_indices is False, no need to populate box_indices
+                if return_indices:
+                    orig_idx = sorted_index[i * num_anchors + j]
+                    box_indices[i, num_valid_boxes_local[0]] = indices[i, 
orig_idx]
+                num_valid_boxes_local[0] += 1
+
+        if isinstance(max_output_size, int):
+            max_output_size = tvm.tir.const(max_output_size)
+
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Apply nms
             with ib.for_range(0, valid_count[i]) as j:
-                with ib.for_range(0, j) as k:
-                    offset_k = k * box_data_length
-                    with ib.if_scope(
-                        tvm.tir.all(
-                            out[base_idx + offset_k + score_index] > 0,
-                            tvm.tir.any(id_index < 0, out[base_idx + offset_k 
+ id_index] >= 0),
-                        )
-                    ):
-                        offset_j = j * box_data_length
-                        with ib.if_scope(
-                            tvm.tir.all(
-                                j > k,
-                                out[base_idx + offset_k + score_index] > 0,
-                                tvm.tir.any(id_index < 0, out[base_idx + 
offset_j + id_index] >= 0),
-                                tvm.tir.any(
-                                    force_suppress > 0,
-                                    id_index < 0,
-                                    out[base_idx + offset_k + id_index]
-                                    == out[base_idx + offset_j + id_index],
-                                ),
-                            )
-                        ):
-                            iou = calculate_overlap(
-                                out,
-                                base_idx + offset_j + coord_start,
-                                base_idx + offset_k + coord_start,
-                            )
-                            with ib.if_scope(iou >= iou_threshold):
-                                out[base_idx + offset_j + score_index] = -1.0
-                                with ib.if_scope(id_index >= 0):
-                                    out[base_idx + offset_j + id_index] = -1.0
-                                box_indices[i * num_anchors + j] = -1
-    with ib.new_scope():
-        nthread_tx = max_threads
-        nthread_bx = num_anchors // max_threads + 1
-        nthread_by = batch_size
-        nthread_bz = box_data_length
-        tx = te.thread_axis("threadIdx.x")
-        bx = te.thread_axis("blockIdx.x")
-        by = te.thread_axis("blockIdx.y")
-        bz = te.thread_axis("blockIdx.z")
-        ib.scope_attr(tx, "thread_extent", nthread_tx)
-        ib.scope_attr(bx, "thread_extent", nthread_bx)
-        ib.scope_attr(by, "thread_extent", nthread_by)
-        ib.scope_attr(bz, "thread_extent", nthread_bz)
-        tid = bx * max_threads + tx
-        i = by
-        j = tid
-        k = bz
-        base_idx = i * num_anchors * box_data_length
-        with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
-            pass
-        with ib.else_scope():
-            with ib.if_scope(j < valid_count[i]):
-                offset_j = j * box_data_length
-                out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
-                box_indices[i * num_anchors + j] = j
-
-    with ib.new_scope():
-        num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", 
scope="local")
-        bx = te.thread_axis("blockIdx.x")
-        ib.scope_attr(bx, "thread_extent", batch_size)
-        i = bx
-        base_idx = i * num_anchors * box_data_length
-        # Set invalid entry to be -1
-        with ib.for_range(0, num_anchors - valid_count[i]) as j:
-            with ib.for_range(0, box_data_length) as k:
-                out[base_idx + (j + valid_count[i]) * box_data_length + k] = 
-1.0
-            box_indices[i * num_anchors + j + valid_count[i]] = -1
-        # Only return max_output_size number of valid boxes
-        num_valid_boxes[0] = 0
-        with ib.if_scope(max_output_size > 0):
-            with ib.for_range(0, valid_count[i]) as j:
-                offset_j = j * box_data_length
-                with ib.if_scope(out[base_idx + offset_j] >= 0):
-                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
-                        with ib.for_range(0, box_data_length) as k:
-                            out[base_idx + offset_j + k] = -1.0
-                        box_indices[i * num_anchors + j] = -1
+                with ib.if_scope(
+                    tvm.tir.any(id_index < 0, out[base_idx + j * 
box_data_length + id_index] >= 0)
+                ):
+                    with ib.if_scope(max_output_size > 0):
+                        # No need to do more iteration if we already reach 
max_output_size boxes
+                        with ib.if_scope(num_valid_boxes_local[0] < 
max_output_size):
+                            nms_inner_loop(ib, j)
                     with ib.else_scope():
-                        num_valid_boxes[0] += 1
+                        nms_inner_loop(ib, j)
 
-    if return_indices:
-        with ib.new_scope():
-            nthread_tx = max_threads
-            nthread_bx = batch_size // max_threads + 1
-            tx = te.thread_axis("threadIdx.x")
-            bx = te.thread_axis("blockIdx.x")
-            ib.scope_attr(tx, "thread_extent", nthread_tx)
-            ib.scope_attr(bx, "thread_extent", nthread_bx)
-            i = bx * max_threads + tx
-            with ib.if_scope(i < batch_size):
-                with ib.for_range(0, valid_count[i]) as j:
-                    idx = box_indices[i * num_anchors + j]
-                    with ib.if_scope(idx >= 0):
-                        box_indices[i * num_anchors + j] = indices[i * 
num_anchors + idx]
+            num_valid_boxes[i] = num_valid_boxes_local[0]
+
+        with ib.else_scope():
+            num_valid_boxes[i] = 0
 
     return ib.get()
 
@@ -816,13 +745,11 @@ def non_max_suppression(
         sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", 
data_alignment=8
     )
 
-    indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, 
"indices_buf", data_alignment=8)
-
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", 
data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, 
"indices_buf", data_alignment=8)
 
-    out, box_indices = te.extern(
-        [data.shape, score_shape],
+    out, box_indices, num_valid_boxes = te.extern(
+        [data.shape, score_shape, [batch_size, 1]],
         [data, sort_tensor, valid_count, indices],
         lambda ins, outs: nms_ir(
             ins[0],
@@ -831,6 +758,7 @@ def non_max_suppression(
             ins[3],
             outs[0],
             outs[1],
+            outs[2],
             max_output_size,
             iou_threshold,
             force_suppress,
@@ -840,24 +768,13 @@ def non_max_suppression(
             score_index,
             return_indices,
         ),
-        dtype=[data.dtype, "int32"],
+        dtype=[data.dtype, "int32", "int32"],
         in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
     )
+
     if return_indices:
-        out_shape = box_indices.shape
-        valid_box_count_shape = [box_indices.shape[0], 1]
-        valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", 
"valid_box_count")
-        output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
-        return te.extern(
-            [out_shape, valid_box_count_shape],
-            [box_indices],
-            lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], 
outs[1]),
-            dtype="int32",
-            out_buffers=[output, valid_box_count],
-            name="rearrange_indices_out_gpu",
-            tag="rearrange_indices_out_gpu",
-        )
+        return [box_indices, num_valid_boxes]
 
     return out

Reply via email to