This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3d9ae3e [Relay][Topi] Fix GPU NMS when return_indices is True (#7005)
3d9ae3e is described below
commit 3d9ae3ec32e7ef908c96abbcddc8b36dff5e0c95
Author: Yao Wang <[email protected]>
AuthorDate: Sat Dec 5 00:56:46 2020 -0800
[Relay][Topi] Fix GPU NMS when return_indices is True (#7005)
* Add rearrange_indices
* Fix output type
* Clean test
* Fix pylint
* Fix CPU nms multi-batch
* Disable test
* Minor fix
* Minor fix
---
python/tvm/topi/cuda/nms.py | 93 ++++++++++++++++++++++++++--
python/tvm/topi/vision/nms.py | 24 +++----
tests/python/relay/test_any.py | 91 +++++++++++++++++++++++----
tests/python/relay/test_op_level5.py | 4 +-
tests/python/topi/python/test_topi_vision.py | 2 +-
5 files changed, 184 insertions(+), 30 deletions(-)
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 689298e..46d7f98 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -54,6 +54,68 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
+def rearrange_indices_out_ir(data, out, valid_box_count):
+ """Hybrid routine to rearrange nms output to
+ move all valid entries to top.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor or numpy NDArray
+ tensor with shape [batch_size, num_anchors].
+
+
+ Returns
+ -------
+ stmt : Stmt
+ The result IR statement.
+ """
+ batch_size = data.shape[0]
+ num_anchors = data.shape[1]
+
+ ib = tvm.tir.ir_builder.create()
+ data = ib.buffer_ptr(data)
+ out = ib.buffer_ptr(out)
+ valid_box_count = ib.buffer_ptr(valid_box_count)
+
+ one_count = tvm.tir.const(1, dtype="int32")
+ atomic_add_return = ib.allocate(
+ valid_box_count.dtype, (batch_size,), name="atomic_add_return",
scope="local"
+ )
+
+ max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
+ nthread_tx = max_threads
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ len_inner_for = (batch_size * num_anchors) // nthread_tx + 2
+
+ idxd = tvm.tir.indexdiv
+ idxm = tvm.tir.indexmod
+
+ with ib.for_range(0, len_inner_for, name="i") as i:
+ idx = tx * len_inner_for + i
+ batch_idx = idxd(idx, num_anchors)
+ with ib.if_scope(idx < batch_size):
+ valid_box_count[idx] = 0
+ with ib.if_scope(idx < batch_size * num_anchors):
+ with ib.if_scope(data[idx] >= 0):
+ atomic_add_return[batch_idx] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of",
valid_box_count[batch_idx]),
+ one_count,
+ )
+ out[batch_idx * num_anchors + atomic_add_return[batch_idx]] =
data[idx]
+ with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] <
-num_anchors)):
+ atomic_add_return[batch_idx] = atomic_add(
+ tvm.tir.call_intrin("handle", "tir.address_of",
valid_box_count[batch_idx]),
+ one_count,
+ )
+ out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0
+
+ with ib.if_scope(idxm(idx, num_anchors) >=
valid_box_count[batch_idx]):
+ out[idx] = -1
+
+ return ib.get()
+
+
def get_valid_counts_ir(
data, valid_count, out, out_indices, score_threshold, id_index, score_index
):
@@ -198,6 +260,7 @@ def nms_ir(
data,
sorted_index,
valid_count,
+ indices,
out,
box_indices,
max_output_size,
@@ -207,6 +270,7 @@ def nms_ir(
coord_start,
id_index,
score_index,
+ return_indices,
):
"""Low level IR routing for transform location in multibox_detection
operator.
@@ -285,6 +349,7 @@ def nms_ir(
valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)
+ indices = ib.buffer_ptr(indices)
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes",
scope="local")
max_threads =
int(tvm.target.Target.current(allow_none=False).max_num_threads)
@@ -379,6 +444,12 @@ def nms_ir(
with ib.else_scope():
num_valid_boxes[0] += 1
+ if return_indices:
+ with ib.if_scope(j < valid_count[i]):
+ box_idx = box_indices[i * num_anchors + j]
+ with ib.if_scope(box_idx >= 0):
+ box_indices[i * num_anchors + j] = indices[i * num_anchors
+ box_idx]
+
return ib.get()
@@ -502,14 +573,16 @@ def non_max_suppression(
)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf",
data_alignment=8)
+ indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype,
"indices_buf", data_alignment=8)
out, box_indices = te.extern(
[data.shape, score_shape],
- [data, sort_tensor, valid_count],
+ [data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
ins[1],
ins[2],
+ ins[3],
outs[0],
outs[1],
max_output_size,
@@ -519,14 +592,26 @@ def non_max_suppression(
coord_start,
id_index,
score_index,
+ return_indices,
),
dtype=[data.dtype, "int32"],
- in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
+ in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
)
- # TODO(yongwww): Update cuda nms to be consistent with cpu version
+
if return_indices:
- return box_indices
+ out_buf = tvm.tir.decl_buffer(
+ box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8
+ )
+ return te.extern(
+ [box_indices.shape, (batch_size, 1)],
+ [box_indices],
+ lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0],
outs[1]),
+ dtype=[box_indices.dtype, valid_count.dtype],
+ in_buffers=[out_buf],
+ name="rearrange_indices_out",
+ tag="rearrange_indices_out",
+ )
return out
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index 76e1808..b076fde 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -52,15 +52,16 @@ def hybrid_rearrange_box_out(data, one, batch_size,
num_anchors):
"""
elem_length = data.shape[2]
output = output_tensor((batch_size, num_anchors, elem_length), data.dtype)
+ valid_indices = allocate((batch_size,), "int32")
for i in parallel(batch_size):
- valid_idx = 0
+ valid_indices[i] = 0
for j in range(num_anchors):
if data[i, j, 0] >= 0:
for k in range(elem_length):
- output[i, valid_idx, k] = data[i, j, k]
- valid_idx += 1
- if j >= valid_idx:
+ output[i, valid_indices[i], k] = data[i, j, k]
+ valid_indices[i] += 1
+ if j >= valid_indices[i]:
for k in range(elem_length):
output[i, j, k] = -one
return output
@@ -100,19 +101,20 @@ def hybrid_rearrange_indices_out(data, one, batch_size,
num_anchors):
"""
valid_box_count = output_tensor((batch_size, 1), "int32")
output = output_tensor((batch_size, num_anchors), data.dtype)
+ valid_indices = allocate((batch_size,), "int32")
for i in parallel(batch_size):
- valid_idx = 0
+ valid_indices[i] = 0
for j in range(num_anchors):
if data[i, j] >= 0:
- output[i, valid_idx] = data[i, j]
- valid_idx += 1
+ output[i, valid_indices[i]] = data[i, j]
+ valid_indices[i] += 1
if data[i, j] > num_anchors or data[i, j] < -num_anchors:
- output[i, valid_idx] = 0
- valid_idx += 1
- if j >= valid_idx:
+ output[i, valid_indices[i]] = 0
+ valid_indices[i] += 1
+ if j >= valid_indices[i]:
output[i, j] = -one
- valid_box_count[i, 0] = valid_idx
+ valid_box_count[i, 0] = valid_indices[i]
return output, valid_box_count
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index ddf8e98..da029e1 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -25,6 +25,8 @@ from tvm.relay.testing import run_infer_type as infer_type
from utils.assert_diagnostic import DiagnosticTesting
import tvm.topi.testing
+import os
+
def int32(val):
return relay.const(val, "int32")
@@ -38,27 +40,43 @@ def any_dims(ndim):
def check_result(
- args, mod, expected, flatten=False, assert_shape=False, only_vm=False,
targets=None
+ args,
+ mod,
+ expected,
+ flatten=False,
+ assert_shape=False,
+ only_vm=False,
+ targets=None,
+ disable_targets=None,
):
+ if not isinstance(expected, list):
+ expected = [expected]
for kind in ["debug", "vm"]:
targets = targets or tvm.testing.enabled_targets()
for tgt, ctx in targets:
+ if disable_targets and tgt in disable_targets:
+ continue
if kind == "debug" and (only_vm or ctx.device_type !=
tvm.cpu().device_type):
continue
ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt)
result = ex.evaluate()(*args)
- result = result.asnumpy()
- if assert_shape:
- assert result.shape == expected, "Shape mismatch: expect %s
but got %s." % (
- str(expected),
- str(result.shape),
- )
- return
+ if isinstance(result, tvm.runtime.container.ADT):
+ result = [r.asnumpy() for r in result]
+ else:
+ result = [result.asnumpy()]
- if flatten:
- result = result.flatten()
- expected = expected.flatten()
- tvm.testing.assert_allclose(result, expected, atol=2e-6)
+ for r, e in zip(result, expected):
+ if assert_shape:
+ assert r.shape == e, "Shape mismatch: expect %s but got
%s." % (
+ str(e),
+ str(r),
+ )
+ return
+
+ if flatten:
+ r = r.flatten()
+ e = e.flatten()
+ tvm.testing.assert_allclose(r, e, atol=2e-6)
def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
@@ -1364,5 +1382,54 @@ def test_any_where():
)
+# TODO(kevinthesun): enable gpu test when Thrust is available in ci.
+# @tvm.testing.uses_gpu
+def test_non_max_suppression():
+ x0 = relay.var("x0", relay.ty.TensorType((1, relay.Any(), 6), "float32"))
+ x1 = relay.var("x1", relay.ty.TensorType((1,), "int32"))
+ x2 = relay.var("x2", relay.ty.TensorType((1, relay.Any()), "int32"))
+ x3 = relay.var("x3", relay.ty.TensorType((), "int32"))
+ z = relay.vision.non_max_suppression(
+ x0,
+ x1,
+ x2,
+ x3,
+ iou_threshold=0.5,
+ force_suppress=True,
+ top_k=2,
+ return_indices=True,
+ invalid_to_bottom=False,
+ )
+ z = z.astuple()
+ func = relay.Function([x0, x1, x2, x3], z)
+ mod = tvm.IRModule()
+ mod["main"] = func
+
+ np_data = np.array(
+ [
+ [
+ [0, 0.8, 1, 20, 25, 45],
+ [1, 0.7, 30, 60, 50, 80],
+ [0, 0.4, 4, 21, 19, 40],
+ [2, 0.9, 35, 61, 52, 79],
+ [1, 0.5, 100, 60, 70, 110],
+ ]
+ ]
+ ).astype("float32")
+ np_valid_count = np.array([4]).astype("int32")
+ np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
+ np_max_output_size = -1
+ np_indices_result = np.array([[4, 0, -1, -1, -1]])
+ np_valid_box_count = np.array([[2]]).astype("int32")
+
+ check_result(
+ [np_data, np_valid_count, np_indices, np_max_output_size],
+ mod,
+ [np_indices_result, np_valid_box_count],
+ only_vm=False,
+ disable_targets=["nvptx"],
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level5.py
b/tests/python/relay/test_op_level5.py
index 5a5a12c..9e9aaf8 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -393,8 +393,8 @@ def test_non_max_suppression():
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
- if target == "cuda":
- return
+ if target == "nvptx":
+ continue
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data,
x2_data, x3_data)
tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(),
ref_indices_res, rtol=1e-5)
op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data,
x2_data, x3_data)
diff --git a/tests/python/topi/python/test_topi_vision.py
b/tests/python/topi/python/test_topi_vision.py
index 22c9045..6d6353e 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -202,7 +202,7 @@ def verify_non_max_suppression(
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape,
dtype="int32"), ctx)
- if device == "llvm":
+ if device in ["llvm", "cuda"]:
f = tvm.build(indices_s, [data, valid_count, indices,
indices_out[0]], device)
f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out)
else: