This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bad149e  [TOPI] Fix GPU Dynamic Op Schedule (#7117)
bad149e is described below

commit bad149ed8a555444d813537608ee5cea9e95e97e
Author: Yao Wang <[email protected]>
AuthorDate: Thu Dec 17 13:54:44 2020 -0800

    [TOPI] Fix GPU Dynamic Op Schedule (#7117)
    
    * Fix GPU dynamic op schedules
    
    * Fix dynamic shape nms
    
    * Fix
    
    * Fix test format
---
 python/tvm/topi/cuda/conv2d_transpose_nchw.py |  7 +++-
 python/tvm/topi/cuda/injective.py             | 13 ++++++-
 python/tvm/topi/cuda/nms.py                   | 49 +++++++++++++++++++++++++--
 python/tvm/topi/cuda/sort.py                  |  3 ++
 src/runtime/contrib/thrust/thrust.cu          |  9 +++++
 src/runtime/vm/vm.cc                          | 17 ++++++++--
 tests/python/relay/test_any.py                | 29 ++++++++++++++--
 7 files changed, 117 insertions(+), 10 deletions(-)

diff --git a/python/tvm/topi/cuda/conv2d_transpose_nchw.py b/python/tvm/topi/cuda/conv2d_transpose_nchw.py
index 609d1ac..3b70417 100644
--- a/python/tvm/topi/cuda/conv2d_transpose_nchw.py
+++ b/python/tvm/topi/cuda/conv2d_transpose_nchw.py
@@ -179,7 +179,10 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
             ##### space definition begin #####
             n, f, y, x = s[conv].op.axis
             rc = s[conv].op.reduce_axis[0]
-            cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
+            # TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
+            bs = pad_data.shape[0]
+            n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
+            cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
             cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
             cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
             cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
@@ -194,6 +197,8 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
 
             if cfg.is_fallback:
                 N, F, Y, X = get_const_tuple(conv.shape)
+                if not isinstance(N, int):
+                    N = 1
                 _fallback_schedule(N, F, Y, X)
 
             ##### space definition end #####
diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py
index 60fb12e..7f0790a 100644
--- a/python/tvm/topi/cuda/injective.py
+++ b/python/tvm/topi/cuda/injective.py
@@ -44,8 +44,16 @@ def schedule_injective_from_existing(sch, out):
     # bandwidth.
     vector_width = 4 if out.dtype == "float16" else 1
 
+    is_dynamic_output = False
+    for dim in out.shape:
+        if not isinstance(dim, tvm.tir.IntImm):
+            is_dynamic_output = True
+            break
+
+    out_len = utils.prod(out.shape)
+
     try:
-        const_size = utils.get_const_int(utils.prod(out.shape))
+        const_size = utils.get_const_int(out_len)
         need_block_split = const_size > max_block * num_thread * vector_width
     except ValueError:
         need_block_split = False
@@ -61,6 +69,9 @@ def schedule_injective_from_existing(sch, out):
         sch[out].bind(bx, te.thread_axis("blockIdx.x"))
         sch[out].bind(tx, te.thread_axis("threadIdx.x"))
     else:
+        # Use less threads for dynamic shape ops to avoid runtime error.
+        if is_dynamic_output:
+            num_thread //= 2
         bx, tx = sch[out].split(fused, factor=num_thread)
         sch[out].bind(tx, te.thread_axis("threadIdx.x"))
         sch[out].bind(bx, te.thread_axis("blockIdx.x"))
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index d0915d9..2733970 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -22,7 +22,6 @@ from tvm import te
 
 from tvm.tir import if_then_else
 from .sort import argsort, argsort_thrust
-from .. import tag
 
 
 def cuda_atomic_add_rule(op):
@@ -95,7 +94,7 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
     with ib.new_scope():
         i = te.thread_axis("blockIdx.x")
         ib.scope_attr(i, "thread_extent", batch_size)
-        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
+        valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
         valid_idx[0] = 0
         with ib.for_range(0, num_anchors, name="j") as j:
             with ib.if_scope(data[i, j] >= 0):
@@ -654,6 +653,35 @@ def nms_ir(
     return ib.get()
 
 
+def _fetch_score_ir(data, score, axis):
+    """
+    Fetch score from data.
+    This routine is required for dynamic shape nms.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    score = ib.buffer_ptr(score)
+    with ib.if_scope(num_anchors > 0):
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        nthread_tx = max_threads
+        nthread_bx = batch_size * num_anchors // max_threads + 1
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            score[tid] = data[tid * elem_length + axis]
+
+    return ib.get()
+
+
 def non_max_suppression(
     data,
     valid_count,
@@ -754,7 +782,22 @@ def non_max_suppression(
     )
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
+    score_tensor = te.extern(
+        [score_shape],
+        [data],
+        lambda ins, outs: _fetch_score_ir(
+            ins[0],
+            outs[0],
+            score_axis,
+        ),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[score_buf],
+        name="fetch_score",
+        tag="fetch_score",
+    )
     target = tvm.target.Target.current()
     if (
         target
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 329f0fb..e4e7c53 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -565,6 +565,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
         tag="topk_gpu",
     )
 
+    if isinstance(k, tvm.tir.IntImm):
+        k = k.value
+
     if not isinstance(k, int) or k > 0:
         beg = [0] * ndim
+        end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 8ccefc5..dddbb04 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -205,6 +205,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     if (value_dtype == "int32") {
       thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
                                           for_scatter);
+    } else if (value_dtype == "int64") {
+      thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
+                                              for_scatter);
     } else if (value_dtype == "float32") {
       thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
                                             for_scatter);
@@ -215,6 +218,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     if (value_dtype == "int32") {
       thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
                                               for_scatter);
+    } else if (value_dtype == "int64") {
+      thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
+                                                  for_scatter);
     } else if (value_dtype == "float32") {
       thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
                                                 for_scatter);
@@ -225,6 +231,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     if (value_dtype == "int32") {
       thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
                                             for_scatter);
+    } else if (value_dtype == "int64") {
+      thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
+                                              for_scatter);
     } else if (value_dtype == "float32") {
       thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
                                               for_scatter);
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 473b5d7..3f890ba 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -245,6 +245,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
   std::vector<int> codes(arity);
   runtime::TVMArgsSetter setter(values.data(), codes.data());
   int idx = 0;
+  bool is_empty_output = false;
   for (Index i = 0; i < arg_count; i++) {
     if (const auto* dt_cell = args[i].as<ADTObj>()) {
       for (size_t fi = 0; fi < dt_cell->size; ++fi) {
@@ -254,12 +255,24 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
       }
     } else {
       auto nd_array = Downcast<NDArray>(args[i]);
+      // We can safely skip CallPacked if there is only one
+      // output and it is empty.
+      if (i == arg_count - 1 && output_size == 1) {
+        for (const auto& dim : nd_array.Shape()) {
+          if (!dim) {
+            is_empty_output = true;
+            break;
+          }
+        }
+      }
       setter(idx++, nd_array);
     }
   }
 
-  TVMRetValue rv;
-  func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
+  if (!is_empty_output) {
+    TVMRetValue rv;
+    func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
+  }
 }
 
 void VirtualMachine::LoadExecutable(const Executable* exec) {
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index da029e1..dfc03c0 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -54,6 +54,7 @@ def check_result(
     for kind in ["debug", "vm"]:
         targets = targets or tvm.testing.enabled_targets()
         for tgt, ctx in targets:
+            print(tgt)
             if disable_targets and tgt in disable_targets:
                 continue
             if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
@@ -199,6 +200,15 @@ def test_any_concat():
     ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     check_result([x_np, y_np], mod, ref)
 
+    num_inputs = 25
+    x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)]
+    z = relay.op.concatenate(x, axis=0)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function(x, z)
+    x_np = [np.random.uniform(size=(1,)).astype("float32") for _ in range(num_inputs)]
+    ref = np.concatenate(x_np, axis=0)
+    check_result(x_np, mod, ref)
+
 
def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var("x", shape=x_shape, dtype="float32")
@@ -572,9 +582,7 @@ def verify_any_conv2d_transpose_nchw(
     mod["main"] = relay.Function([data, kernel], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
-    check_result(
-        [data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=[("llvm", tvm.cpu())]
-    )
+    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)
 
 
 # TODO(@kevinthesun): Support dynamic input height and width.
@@ -1430,6 +1438,21 @@ def test_non_max_suppression():
         disable_targets=["nvptx"],
     )
 
+    np_data = np.zeros((1, 0, 6)).astype("float32")
+    np_valid_count = np.array([0]).astype("int32")
+    np_indices = np.zeros((1, 0)).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.zeros((1, 0))
+    np_valid_box_count = np.array([[0]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],
+    )
+
 
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to the mailing list.