This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d9ae3e  [Relay][Topi] Fix GPU NMS when return_indices is True (#7005)
3d9ae3e is described below

commit 3d9ae3ec32e7ef908c96abbcddc8b36dff5e0c95
Author: Yao Wang <[email protected]>
AuthorDate: Sat Dec 5 00:56:46 2020 -0800

    [Relay][Topi] Fix GPU NMS when return_indices is True (#7005)
    
    * Add rearrange_indices
    
    * Fix output type
    
    * Clean test
    
    * Fix pylint
    
    * Fix CPU nms multi-batch
    
    * Disable test
    
    * Minor fix
    
    * Minor fix
---
 python/tvm/topi/cuda/nms.py                  | 93 ++++++++++++++++++++++++++--
 python/tvm/topi/vision/nms.py                | 24 +++----
 tests/python/relay/test_any.py               | 91 +++++++++++++++++++++++----
 tests/python/relay/test_op_level5.py         |  4 +-
 tests/python/topi/python/test_topi_vision.py |  2 +-
 5 files changed, 184 insertions(+), 30 deletions(-)

diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 689298e..46d7f98 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -54,6 +54,68 @@ def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
+def rearrange_indices_out_ir(data, out, valid_box_count):
+    """Hybrid routine to rearrange nms output to
+    move all valid entries to top.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor or numpy NDArray
+        tensor with shape [batch_size, num_anchors].
+
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+
+    ib = tvm.tir.ir_builder.create()
+    data = ib.buffer_ptr(data)
+    out = ib.buffer_ptr(out)
+    valid_box_count = ib.buffer_ptr(valid_box_count)
+
+    one_count = tvm.tir.const(1, dtype="int32")
+    atomic_add_return = ib.allocate(
+        valid_box_count.dtype, (batch_size,), name="atomic_add_return", 
scope="local"
+    )
+
+    max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+    tx = te.thread_axis("threadIdx.x")
+    ib.scope_attr(tx, "thread_extent", nthread_tx)
+    len_inner_for = (batch_size * num_anchors) // nthread_tx + 2
+
+    idxd = tvm.tir.indexdiv
+    idxm = tvm.tir.indexmod
+
+    with ib.for_range(0, len_inner_for, name="i") as i:
+        idx = tx * len_inner_for + i
+        batch_idx = idxd(idx, num_anchors)
+        with ib.if_scope(idx < batch_size):
+            valid_box_count[idx] = 0
+        with ib.if_scope(idx < batch_size * num_anchors):
+            with ib.if_scope(data[idx] >= 0):
+                atomic_add_return[batch_idx] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", 
valid_box_count[batch_idx]),
+                    one_count,
+                )
+                out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 
data[idx]
+            with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < 
-num_anchors)):
+                atomic_add_return[batch_idx] = atomic_add(
+                    tvm.tir.call_intrin("handle", "tir.address_of", 
valid_box_count[batch_idx]),
+                    one_count,
+                )
+                out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0
+
+            with ib.if_scope(idxm(idx, num_anchors) >= 
valid_box_count[batch_idx]):
+                out[idx] = -1
+
+    return ib.get()
+
+
 def get_valid_counts_ir(
     data, valid_count, out, out_indices, score_threshold, id_index, score_index
 ):
@@ -198,6 +260,7 @@ def nms_ir(
     data,
     sorted_index,
     valid_count,
+    indices,
     out,
     box_indices,
     max_output_size,
@@ -207,6 +270,7 @@ def nms_ir(
     coord_start,
     id_index,
     score_index,
+    return_indices,
 ):
     """Low level IR routing for transform location in multibox_detection 
operator.
 
@@ -285,6 +349,7 @@ def nms_ir(
     valid_count = ib.buffer_ptr(valid_count)
     out = ib.buffer_ptr(out)
     box_indices = ib.buffer_ptr(box_indices)
+    indices = ib.buffer_ptr(indices)
     num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", 
scope="local")
 
     max_threads = 
int(tvm.target.Target.current(allow_none=False).max_num_threads)
@@ -379,6 +444,12 @@ def nms_ir(
                     with ib.else_scope():
                         num_valid_boxes[0] += 1
 
+        if return_indices:
+            with ib.if_scope(j < valid_count[i]):
+                box_idx = box_indices[i * num_anchors + j]
+                with ib.if_scope(box_idx >= 0):
+                    box_indices[i * num_anchors + j] = indices[i * num_anchors 
+ box_idx]
+
     return ib.get()
 
 
@@ -502,14 +573,16 @@ def non_max_suppression(
     )
 
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", 
data_alignment=8)
+    indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, 
"indices_buf", data_alignment=8)
 
     out, box_indices = te.extern(
         [data.shape, score_shape],
-        [data, sort_tensor, valid_count],
+        [data, sort_tensor, valid_count, indices],
         lambda ins, outs: nms_ir(
             ins[0],
             ins[1],
             ins[2],
+            ins[3],
             outs[0],
             outs[1],
             max_output_size,
@@ -519,14 +592,26 @@ def non_max_suppression(
             coord_start,
             id_index,
             score_index,
+            return_indices,
         ),
         dtype=[data.dtype, "int32"],
-        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+        in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
         name="nms",
         tag="nms",
     )
-    # TODO(yongwww): Update cuda nms to be consistent with cpu version
+
     if return_indices:
-        return box_indices
+        out_buf = tvm.tir.decl_buffer(
+            box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8
+        )
+        return te.extern(
+            [box_indices.shape, (batch_size, 1)],
+            [box_indices],
+            lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], 
outs[1]),
+            dtype=[box_indices.dtype, valid_count.dtype],
+            in_buffers=[out_buf],
+            name="rearrange_indices_out",
+            tag="rearrange_indices_out",
+        )
 
     return out
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 76e1808..b076fde 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -52,15 +52,16 @@ def hybrid_rearrange_box_out(data, one, batch_size, 
num_anchors):
     """
     elem_length = data.shape[2]
     output = output_tensor((batch_size, num_anchors, elem_length), data.dtype)
+    valid_indices = allocate((batch_size,), "int32")
 
     for i in parallel(batch_size):
-        valid_idx = 0
+        valid_indices[i] = 0
         for j in range(num_anchors):
             if data[i, j, 0] >= 0:
                 for k in range(elem_length):
-                    output[i, valid_idx, k] = data[i, j, k]
-                valid_idx += 1
-            if j >= valid_idx:
+                    output[i, valid_indices[i], k] = data[i, j, k]
+                valid_indices[i] += 1
+            if j >= valid_indices[i]:
                 for k in range(elem_length):
                     output[i, j, k] = -one
     return output
@@ -100,19 +101,20 @@ def hybrid_rearrange_indices_out(data, one, batch_size, 
num_anchors):
     """
     valid_box_count = output_tensor((batch_size, 1), "int32")
     output = output_tensor((batch_size, num_anchors), data.dtype)
+    valid_indices = allocate((batch_size,), "int32")
 
     for i in parallel(batch_size):
-        valid_idx = 0
+        valid_indices[i] = 0
         for j in range(num_anchors):
             if data[i, j] >= 0:
-                output[i, valid_idx] = data[i, j]
-                valid_idx += 1
+                output[i, valid_indices[i]] = data[i, j]
+                valid_indices[i] += 1
             if data[i, j] > num_anchors or data[i, j] < -num_anchors:
-                output[i, valid_idx] = 0
-                valid_idx += 1
-            if j >= valid_idx:
+                output[i, valid_indices[i]] = 0
+                valid_indices[i] += 1
+            if j >= valid_indices[i]:
                 output[i, j] = -one
-        valid_box_count[i, 0] = valid_idx
+        valid_box_count[i, 0] = valid_indices[i]
 
     return output, valid_box_count
 
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index ddf8e98..da029e1 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -25,6 +25,8 @@ from tvm.relay.testing import run_infer_type as infer_type
 from utils.assert_diagnostic import DiagnosticTesting
 import tvm.topi.testing
 
+import os
+
 
 def int32(val):
     return relay.const(val, "int32")
@@ -38,27 +40,43 @@ def any_dims(ndim):
 
 
 def check_result(
-    args, mod, expected, flatten=False, assert_shape=False, only_vm=False, 
targets=None
+    args,
+    mod,
+    expected,
+    flatten=False,
+    assert_shape=False,
+    only_vm=False,
+    targets=None,
+    disable_targets=None,
 ):
+    if not isinstance(expected, list):
+        expected = [expected]
     for kind in ["debug", "vm"]:
         targets = targets or tvm.testing.enabled_targets()
         for tgt, ctx in targets:
+            if disable_targets and tgt in disable_targets:
+                continue
             if kind == "debug" and (only_vm or ctx.device_type != 
tvm.cpu().device_type):
                 continue
             ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt)
             result = ex.evaluate()(*args)
-            result = result.asnumpy()
-            if assert_shape:
-                assert result.shape == expected, "Shape mismatch: expect %s 
but got %s." % (
-                    str(expected),
-                    str(result.shape),
-                )
-                return
+            if isinstance(result, tvm.runtime.container.ADT):
+                result = [r.asnumpy() for r in result]
+            else:
+                result = [result.asnumpy()]
 
-            if flatten:
-                result = result.flatten()
-                expected = expected.flatten()
-            tvm.testing.assert_allclose(result, expected, atol=2e-6)
+            for r, e in zip(result, expected):
+                if assert_shape:
+                    assert r.shape == e, "Shape mismatch: expect %s but got 
%s." % (
+                        str(e),
+                        str(r),
+                    )
+                    return
+
+                if flatten:
+                    r = r.flatten()
+                    e = e.flatten()
+                tvm.testing.assert_allclose(r, e, atol=2e-6)
 
 
 def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
@@ -1364,5 +1382,54 @@ def test_any_where():
     )
 
 
+# TODO(kevinthesun): enable gpu test when Thrust is available in ci.
+# @tvm.testing.uses_gpu
+def test_non_max_suppression():
+    x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32"))
+    x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
+    x2 = relay.var("x2", relay.ty.TensorType((1, relay.Any()), "int32"))
+    x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
+    z = relay.vision.non_max_suppression(
+        x0,
+        x1,
+        x2,
+        x3,
+        iou_threshold=0.5,
+        force_suppress=True,
+        top_k=2,
+        return_indices=True,
+        invalid_to_bottom=False,
+    )
+    z = z.astuple()
+    func = relay.Function([x0, x1, x2, x3], z)
+    mod = tvm.IRModule()
+    mod["main"] = func
+
+    np_data = np.array(
+        [
+            [
+                [0, 0.8, 1, 20, 25, 45],
+                [1, 0.7, 30, 60, 50, 80],
+                [0, 0.4, 4, 21, 19, 40],
+                [2, 0.9, 35, 61, 52, 79],
+                [1, 0.5, 100, 60, 70, 110],
+            ]
+        ]
+    ).astype("float32")
+    np_valid_count = np.array([4]).astype("int32")
+    np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.array([[4, 0, -1, -1, -1]])
+    np_valid_box_count = np.array([[2]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level5.py 
b/tests/python/relay/test_op_level5.py
index 5a5a12c..9e9aaf8 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -393,8 +393,8 @@ def test_non_max_suppression():
             intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
             op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
             tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
-            if target == "cuda":
-                return
+            if target == "nvptx":
+                continue
             op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, 
x2_data, x3_data)
             tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), 
ref_indices_res, rtol=1e-5)
             op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, 
x2_data, x3_data)
diff --git a/tests/python/topi/python/test_topi_vision.py 
b/tests/python/topi/python/test_topi_vision.py
index 22c9045..6d6353e 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -202,7 +202,7 @@ def verify_non_max_suppression(
         tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
 
         tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, 
dtype="int32"), ctx)
-        if device == "llvm":
+        if device in ["llvm", "cuda"]:
             f = tvm.build(indices_s, [data, valid_count, indices, 
indices_out[0]], device)
             f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
         else:

Reply via email to