This is an automated email from the ASF dual-hosted git repository.
laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 82942fb [TOPI] Simplify GPU NMS IR and optimize a bit (#7136)
82942fb is described below
commit 82942fb33fd6e3572897c815af16905c4f75c2a4
Author: masahi <[email protected]>
AuthorDate: Tue Dec 22 02:17:26 2020 +0900
[TOPI] Simplify GPU NMS IR and optimize a bit (#7136)
* remove get_valid_counts from pytorch nms
* fix pytorch nms for negative score
* merge reset by -1
* move max_output_size handling to triangle loop
* update torch nms test
* fuse the last two kernels
* parallelize the first kernel
* merge first and last kernel
* remove unnecessary cases
* fix typo
* revert pytorch frontend change
* fuse rearrange step with triangle loop
* fix max_output_size handling
* check if already suppressed
* fix topi vision test by wrapping tir const around int argument
* fix for num anchors = 0 case
* fix missing zero init of num valid boxes when the input is empty
* add some comments and missing doc
* typo fix
* add a guard against zero dim grid / thread block inside ir_builder
* typo fix
* trigger CI
---
python/tvm/tir/ir_builder.py | 4 +
python/tvm/topi/cuda/nms.py | 279 +++++++++++++++----------------------------
2 files changed, 102 insertions(+), 181 deletions(-)
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 75c5c29..6dcc858 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -21,6 +21,7 @@ from tvm.ir import container as _container, PointerType, PrimType
from . import stmt as _stmt
from . import expr as _expr
+from . import op
class WithScope(object):
@@ -200,6 +201,9 @@ class IRBuilder(object):
node = _expr.StringImm(node)
if isinstance(value, string_types):
value = _expr.StringImm(value)
+ # thread_extent could be zero for dynamic workloads
+ if attr_key == "thread_extent":
+ value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 2733970..cea287e 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -51,68 +51,8 @@ def atomic_add(x, y):
return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
-def rearrange_indices_out_ir(data, output, valid_box_count):
- """Hybrid routine to rearrange nms output to
- move all valid entries to top.
-
- Parameters
- ----------
- data : tvm.te.Tensor or numpy NDArray
- NMS output. 3-D tensor with shape
- [batch_size, num_anchors, 6] or
- [batch_size, num_anchors, 5], or 2-D
- tensor with shape [batch_size, num_anchors].
-
- one: tvm.tir.const
- Constant one with the same dtype as data.
-
- batch_size: tvm.tir.IntImm or tvm.tir.Var
- Batch size. We need to pass it in since hybrid script doesn't support
- binding variable to symbolic dim.
-
- num_anchors: tvm.tir.IntImm or tvm.tir.Var
- Number of anchors.
-
- Returns
- -------
- output : tvm.te.Tensor or numpy NDArray
- 2-D tensor with shape [batch_size, num_anchors].
-
- valid_box_count : tvm.te.Tensor or numpy NDArray
- Tensor with shape [batch_size, 1], indicates
- the valid number of boxes.
- """
- batch_size = data.shape[0]
- num_anchors = data.shape[1]
-
- ib = tvm.tir.ir_builder.create()
-
- data = ib.buffer_ptr(data)
- valid_box_count = ib.buffer_ptr(valid_box_count)
- output = ib.buffer_ptr(output)
-
- with ib.new_scope():
- i = te.thread_axis("blockIdx.x")
- ib.scope_attr(i, "thread_extent", batch_size)
- valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
- valid_idx[0] = 0
- with ib.for_range(0, num_anchors, name="j") as j:
- with ib.if_scope(data[i, j] >= 0):
- with ib.if_scope(data[i, j] > num_anchors):
- output[i, valid_idx[0]] = 0
- valid_idx[0] = valid_idx[0] + 1
- with ib.else_scope():
- output[i, valid_idx[0]] = data[i, j]
- valid_idx[0] = valid_idx[0] + 1
- with ib.else_scope():
- with ib.if_scope(data[i, j] < -num_anchors):
- output[i, valid_idx[0]] = 0
- valid_idx[0] = valid_idx[0] + 1
- with ib.if_scope(j >= valid_idx[0]):
- output[i, j] = -1
- valid_box_count[i, 0] = valid_idx[0]
-
- return ib.get()
+def ceil_div(a, b):
+ return tvm.tir.indexdiv(a + b - 1, b)
def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
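The new ceil_div helper is a rounding-up integer division on TIR expressions, used below to size the grid: ceil_div(num_anchors, max_threads) blocks of max_threads threads cover all anchors, with the tail handled by an in-kernel bounds check. A quick sanity check with concrete constants (illustrative only):

    import tvm

    a = tvm.tir.const(1000, "int32")
    b = tvm.tir.const(256, "int32")
    # indexdiv(1000 + 256 - 1, 256) == 4: four blocks of 256 threads
    # cover 1000 anchors; threads with index >= 1000 do nothing.
    print(tvm.arith.Analyzer().simplify(tvm.tir.indexdiv(a + b - 1, b)))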
@@ -400,6 +340,7 @@ def nms_ir(
indices,
out,
box_indices,
+ num_valid_boxes,
max_output_size,
iou_threshold,
force_suppress,
@@ -430,7 +371,15 @@ def nms_ir(
is not used before non_max_suppression.
out : Buffer
- Output buffer.
+ Output buffer, to be filled with sorted boxes.
+
+ box_indices : Buffer
+ An indices tensor mapping sorted indices to original indices.
+ This is the first output of NMS when return_indices=True.
+
+ num_valid_boxes : Buffer
+ Records the number of boxes that have survived IOU tests.
+ This is the second output of NMS when return_indices=True.
max_output_size : int
Max number of output valid boxes for each instance.
@@ -509,6 +458,7 @@ def nms_ir(
sorted_index = ib.buffer_ptr(sorted_index)
valid_count = ib.buffer_ptr(valid_count)
indices = ib.buffer_ptr(indices)
+ num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
out = ib.buffer_ptr(out)
box_indices = ib.buffer_ptr(box_indices)
@@ -523,9 +473,15 @@ def nms_ir(
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
with ib.new_scope():
+ nthread_tx = max_threads
+ nthread_bx = ceil_div(num_anchors, max_threads)
nthread_by = batch_size
+ tx = te.thread_axis("threadIdx.x")
+ bx = te.thread_axis("blockIdx.x")
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
+ ib.scope_attr(tx, "thread_extent", nthread_tx)
+ ib.scope_attr(bx, "thread_extent", nthread_bx)
i = by
base_idx = i * num_anchors * box_data_length
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
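The hunk above parallelizes the first kernel: blockIdx.y selects the batch and blockIdx.x * max_threads + threadIdx.x selects the anchor, replacing the old serial for_range. A plain-Python emulation of the index mapping (a sketch, not from the patch):

    def ceil_div(a, b):
        return (a + b - 1) // b

    batch_size, num_anchors, max_threads = 2, 1000, 256

    covered = set()
    for by in range(batch_size):                              # blockIdx.y
        for bx in range(ceil_div(num_anchors, max_threads)):  # blockIdx.x
            for tx in range(max_threads):                     # threadIdx.x
                j = bx * max_threads + tx
                if j < num_anchors:                           # bounds guard
                    covered.add((by, j))

    # every (batch, anchor) pair is visited exactly once
    assert len(covered) == batch_size * num_anchors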
@@ -533,122 +489,95 @@ def nms_ir(
nkeep = if_then_else(
tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]
)
- with ib.for_range(0, nkeep) as j:
+ j = bx * max_threads + tx
+ with ib.if_scope(j < num_anchors):
+ box_indices[i * num_anchors + j] = -1
+ with ib.if_scope(j < nkeep):
+ # Fill in out with sorted boxes
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = data[
(base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
]
- box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
- with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
- with ib.for_range(0, valid_count[i] - nkeep) as j:
+ with ib.else_scope():
+ # Indices > nkeep are discarded
+ with ib.if_scope(j < num_anchors):
with ib.for_range(0, box_data_length) as k:
- out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
- box_indices[i * num_anchors + (j + nkeep)] = -1
+ out[(base_idx + j * box_data_length + k)] = -1.0
+ with ib.else_scope():
+ with ib.if_scope(j < valid_count[i]):
+ with ib.for_range(0, box_data_length) as k:
+ offset = base_idx + j * box_data_length + k
+ out[offset] = data[offset]
+ box_indices[i * num_anchors + j] = j
+
with ib.new_scope():
nthread_by = batch_size
by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
i = by
base_idx = i * num_anchors * box_data_length
+ num_valid_boxes_local = ib.allocate(
+ "int32", (1,), name="num_valid_boxes_local", scope="local"
+ )
+ num_valid_boxes_local[0] = 0
+
+ def nms_inner_loop(ib, j):
+ offset_j = j * box_data_length
+
+ with ib.for_range(0, j) as k:
+ offset_k = k * box_data_length
+
+ with ib.if_scope(
+ tvm.tir.all(
+ out[base_idx + offset_j + score_index] > -1.0, # if already suppressed
+ out[base_idx + offset_k + score_index] > 0,
+ tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
+ tvm.tir.any(
+ force_suppress > 0,
+ id_index < 0,
+ out[base_idx + offset_k + id_index]
+ == out[base_idx + offset_j + id_index],
+ ),
+ )
+ ):
+ iou = calculate_overlap(
+ out,
+ base_idx + offset_j + coord_start,
+ base_idx + offset_k + coord_start,
+ )
+ with ib.if_scope(iou >= iou_threshold):
+ out[base_idx + offset_j + score_index] = -1.0
+ with ib.if_scope(id_index >= 0):
+ out[base_idx + offset_j + id_index] = -1.0
+
+ # Has the box j survived IOU tests?
+ with ib.if_scope(out[base_idx + offset_j + score_index] > -1.0):
+ # When return_indices is False, no need to populate box_indices
+ if return_indices:
+ orig_idx = sorted_index[i * num_anchors + j]
+ box_indices[i, num_valid_boxes_local[0]] = indices[i, orig_idx]
+ num_valid_boxes_local[0] += 1
+
+ if isinstance(max_output_size, int):
+ max_output_size = tvm.tir.const(max_output_size)
+
with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
- with ib.for_range(0, j) as k:
- offset_k = k * box_data_length
- with ib.if_scope(
- tvm.tir.all(
- out[base_idx + offset_k + score_index] > 0,
- tvm.tir.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0),
- )
- ):
- offset_j = j * box_data_length
- with ib.if_scope(
- tvm.tir.all(
- j > k,
- out[base_idx + offset_k + score_index] > 0,
- tvm.tir.any(id_index < 0, out[base_idx + offset_j + id_index] >= 0),
- tvm.tir.any(
- force_suppress > 0,
- id_index < 0,
- out[base_idx + offset_k + id_index]
- == out[base_idx + offset_j + id_index],
- ),
- )
- ):
- iou = calculate_overlap(
- out,
- base_idx + offset_j + coord_start,
- base_idx + offset_k + coord_start,
- )
- with ib.if_scope(iou >= iou_threshold):
- out[base_idx + offset_j + score_index] = -1.0
- with ib.if_scope(id_index >= 0):
- out[base_idx + offset_j + id_index] = -1.0
- box_indices[i * num_anchors + j] = -1
- with ib.new_scope():
- nthread_tx = max_threads
- nthread_bx = num_anchors // max_threads + 1
- nthread_by = batch_size
- nthread_bz = box_data_length
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- by = te.thread_axis("blockIdx.y")
- bz = te.thread_axis("blockIdx.z")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- ib.scope_attr(by, "thread_extent", nthread_by)
- ib.scope_attr(bz, "thread_extent", nthread_bz)
- tid = bx * max_threads + tx
- i = by
- j = tid
- k = bz
- base_idx = i * num_anchors * box_data_length
- with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
- pass
- with ib.else_scope():
- with ib.if_scope(j < valid_count[i]):
- offset_j = j * box_data_length
- out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
- box_indices[i * num_anchors + j] = j
-
- with ib.new_scope():
- num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(bx, "thread_extent", batch_size)
- i = bx
- base_idx = i * num_anchors * box_data_length
- # Set invalid entry to be -1
- with ib.for_range(0, num_anchors - valid_count[i]) as j:
- with ib.for_range(0, box_data_length) as k:
- out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
- box_indices[i * num_anchors + j + valid_count[i]] = -1
- # Only return max_output_size number of valid boxes
- num_valid_boxes[0] = 0
- with ib.if_scope(max_output_size > 0):
- with ib.for_range(0, valid_count[i]) as j:
- offset_j = j * box_data_length
- with ib.if_scope(out[base_idx + offset_j] >= 0):
- with ib.if_scope(num_valid_boxes[0] == max_output_size):
- with ib.for_range(0, box_data_length) as k:
- out[base_idx + offset_j + k] = -1.0
- box_indices[i * num_anchors + j] = -1
+ with ib.if_scope(
+ tvm.tir.any(id_index < 0, out[base_idx + j * box_data_length + id_index] >= 0)
+ ):
+ with ib.if_scope(max_output_size > 0):
+ # No need to do more iterations if we have already reached max_output_size boxes
+ with ib.if_scope(num_valid_boxes_local[0] < max_output_size):
+ nms_inner_loop(ib, j)
with ib.else_scope():
- num_valid_boxes[0] += 1
+ nms_inner_loop(ib, j)
- if return_indices:
- with ib.new_scope():
- nthread_tx = max_threads
- nthread_bx = batch_size // max_threads + 1
- tx = te.thread_axis("threadIdx.x")
- bx = te.thread_axis("blockIdx.x")
- ib.scope_attr(tx, "thread_extent", nthread_tx)
- ib.scope_attr(bx, "thread_extent", nthread_bx)
- i = bx * max_threads + tx
- with ib.if_scope(i < batch_size):
- with ib.for_range(0, valid_count[i]) as j:
- idx = box_indices[i * num_anchors + j]
- with ib.if_scope(idx >= 0):
- box_indices[i * num_anchors + j] = indices[i * num_anchors + idx]
+ num_valid_boxes[i] = num_valid_boxes_local[0]
+
+ with ib.else_scope():
+ num_valid_boxes[i] = 0
return ib.get()
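The second kernel now does everything in one pass per batch: boxes are visited in sorted-score order, each box j is tested against every higher-ranked box k < j (the triangle loop), survivors are counted into num_valid_boxes_local, and the loop stops early once max_output_size boxes have been kept. The sequential semantics, sketched in plain Python with a hypothetical iou() argument (not part of the patch):

    def nms_reference(boxes, iou_threshold, max_output_size, iou):
        # boxes are assumed pre-sorted by descending score, as the
        # sorted `out` buffer is when this kernel runs.
        keep = []
        suppressed = [False] * len(boxes)
        for j in range(len(boxes)):
            if 0 < max_output_size <= len(keep):
                break  # mirrors the early-exit guard around nms_inner_loop
            for k in range(j):  # only higher-ranked boxes can suppress j
                if not suppressed[k] and iou(boxes[j], boxes[k]) >= iou_threshold:
                    suppressed[j] = True
                    break
            if not suppressed[j]:
                keep.append(j)
        # analogous to (box_indices, num_valid_boxes) for one batch
        return keep, len(keep)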
@@ -816,13 +745,11 @@ def non_max_suppression(
sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8
)
- indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
-
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
- out, box_indices = te.extern(
- [data.shape, score_shape],
+ out, box_indices, num_valid_boxes = te.extern(
+ [data.shape, score_shape, [batch_size, 1]],
[data, sort_tensor, valid_count, indices],
lambda ins, outs: nms_ir(
ins[0],
@@ -831,6 +758,7 @@ def non_max_suppression(
ins[3],
outs[0],
outs[1],
+ outs[2],
max_output_size,
iou_threshold,
force_suppress,
@@ -840,24 +768,13 @@ def non_max_suppression(
score_index,
return_indices,
),
- dtype=[data.dtype, "int32"],
+ dtype=[data.dtype, "int32", "int32"],
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf, indices_buf],
name="nms",
tag="nms",
)
+
if return_indices:
- out_shape = box_indices.shape
- valid_box_count_shape = [box_indices.shape[0], 1]
- valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", "valid_box_count")
- output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
- return te.extern(
- [out_shape, valid_box_count_shape],
- [box_indices],
- lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
- dtype="int32",
- out_buffers=[output, valid_box_count],
- name="rearrange_indices_out_gpu",
- tag="rearrange_indices_out_gpu",
- )
+ return [box_indices, num_valid_boxes]
return out
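With this change, the GPU non_max_suppression with return_indices=True returns the pair [box_indices, num_valid_boxes] straight from the fused kernel, and the rearrange_indices_out_gpu pass is gone. A rough usage sketch (placeholder shapes and names are illustrative):

    import tvm
    from tvm import te, topi

    data = te.placeholder((1, 1000, 6), name="data")  # boxes with class id and score
    valid_count = te.placeholder((1,), dtype="int32", name="valid_count")
    indices = te.placeholder((1, 1000), dtype="int32", name="indices")

    box_indices, num_valid_boxes = topi.cuda.non_max_suppression(
        data, valid_count, indices,
        max_output_size=-1, iou_threshold=0.5,
        force_suppress=False, top_k=-1, return_indices=True,
    )
    # box_indices:     [1, 1000], kept original indices, -1 padded
    # num_valid_boxes: [1, 1], number of boxes surviving the IOU tests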