This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 054466b  [ONNX] NMS in ONNX (#6839)
054466b is described below

commit 054466be14ea64b0bfe75081b311cabe0d40a991
Author: Matthew Brookhart <[email protected]>
AuthorDate: Mon Dec 14 11:54:07 2020 -0700

    [ONNX] NMS in ONNX (#6839)
    
    * NMS partially working on CPU, fails on GPU
    
    * support dynamic iou_threshold
    
    * WIP NMS with while loops
    
    * working nms with dynamic shapes
    
    * add a test with dynamic score_threshold and pass it
    
    * Fix type checking in lambda lift
    
    * ONNX NMS working on GPU, had to remove threading from some kernels
    
    fix lint
    
    fix lambda lift tests
    
    fix unit tests
    
    respond to review comments
    
    fix lint
    
    * better parallelize get_valid_counts
    
    * improve nms parallelization
    
    * respond to cuda/thrust enablement issue
    
    Co-authored-by: Jared Roesch <[email protected]>
---
 include/tvm/relay/attrs/vision.h            |  12 +-
 python/tvm/relay/frontend/onnx.py           | 269 ++++++++++++++
 python/tvm/relay/op/strategy/generic.py     |   8 +-
 python/tvm/relay/op/vision/nms.py           |   8 +-
 python/tvm/topi/cuda/nms.py                 | 533 +++++++++++++++++++---------
 python/tvm/topi/cuda/sort.py                |   8 +-
 python/tvm/topi/vision/nms.py               |  16 +-
 src/relay/backend/vm/compiler.cc            |   1 +
 src/relay/backend/vm/lambda_lift.cc         |  31 +-
 src/relay/op/tensor/transform.h             |  16 +-
 src/relay/op/vision/nms.cc                  |  29 +-
 tests/python/frontend/onnx/test_forward.py  |  99 +++++-
 tests/python/relay/test_op_level5.py        |   4 +-
 tests/python/relay/test_pass_lambda_lift.py |   3 +
 14 files changed, 827 insertions(+), 210 deletions(-)
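
For reference, a minimal sketch of how the new converter path is exercised end to
end, modeled on the tests added in tests/python/frontend/onnx/test_forward.py (the
graph construction below is illustrative, not lifted verbatim from the tests):

    import onnx
    from onnx import TensorProto, helper
    import tvm
    from tvm import relay

    # A tiny ONNX graph containing a single NonMaxSuppression node.
    node = helper.make_node(
        "NonMaxSuppression",
        inputs=["boxes", "scores", "max_output_boxes_per_class",
                "iou_threshold", "score_threshold"],
        outputs=["selected_indices"],
    )
    graph = helper.make_graph(
        [node],
        "nms_example",
        inputs=[
            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, [1, 6, 4]),
            helper.make_tensor_value_info("scores", TensorProto.FLOAT, [1, 1, 6]),
            helper.make_tensor_value_info(
                "max_output_boxes_per_class", TensorProto.INT64, [1]),
            helper.make_tensor_value_info("iou_threshold", TensorProto.FLOAT, [1]),
            helper.make_tensor_value_info("score_threshold", TensorProto.FLOAT, [1]),
        ],
        outputs=[helper.make_tensor_value_info(
            "selected_indices", TensorProto.INT64, [None, 3])],
    )
    model = helper.make_model(graph)

    mod, params = relay.frontend.from_onnx(model)
    # The conversion produces Relay while loops with dynamic shapes, so the
    # VM executor is required rather than the graph runtime.
    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    # ex.evaluate()(boxes_np, scores_np, max_boxes_np, iou_np, score_np)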

diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h
index 2b905f5..ca2c4a2 100644
--- a/include/tvm/relay/attrs/vision.h
+++ b/include/tvm/relay/attrs/vision.h
@@ -73,14 +73,12 @@ struct MultiBoxTransformLocAttrs : public tvm::AttrsNode<MultiBoxTransformLocAtt
 
 /*! \brief Attributes used in get_valid_counts operator */
 struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
-  double score_threshold;
+  Optional<FloatImm> score_threshold;
   int id_index;
   int score_index;
 
   TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
-    TVM_ATTR_FIELD(score_threshold)
-        .set_default(0.0)
-        .describe("Lower limit of score for valid bounding boxes.");
+    TVM_ATTR_FIELD(score_threshold).describe("Lower limit of score for valid bounding boxes.");
     TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id.");
     TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes.");
   }
@@ -89,7 +87,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
 /*! \brief Attributes used in non_maximum_suppression operator */
 struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
   Optional<Integer> max_output_size;
-  double iou_threshold;
+  Optional<FloatImm> iou_threshold;
   bool force_suppress;
   int top_k;
   int coord_start;
@@ -100,9 +98,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
 
   TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
     TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
-    TVM_ATTR_FIELD(iou_threshold)
-        .set_default(0.5)
-        .describe("Non-maximum suppression iou threshold.");
+    TVM_ATTR_FIELD(iou_threshold).describe("Non-maximum suppression iou threshold.");
     TVM_ATTR_FIELD(force_suppress)
         .set_default(false)
         .describe("Suppress all detections regardless of class_id.");
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index f0d7e2d..23102aa 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2303,6 +2303,274 @@ class If(OnnxOpConverter):
         return _expr.If(cond, then_expr, else_expr)
 
 
+class NonMaxSuppression(OnnxOpConverter):
+    """Operator converter for NonMaxSuppression."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        """
+        High level note: ONNX implements what TF calls combined_non_max_suppression
+        It passes in scores for each box for every class in the output and expects boxes to be
+        analyzed for each class independently
+
+        It also asks for the data to be returned in a particular format.
+
+        To support these, we implement a series of loops:
+        The first loop splits over class number, performs NMS, and collects the outputs.
+        The second (nested) loop takes the outputs and transforms them into the format ONNX wants
+        """
+        # Get parameter values
+        boxes = inputs[0]
+        scores = inputs[1]
+        max_output_boxes_per_class = inputs[2]
+        iou_threshold = inputs[3]
+        score_threshold = inputs[4]
+
+        dtype = infer_type(boxes).checked_type.dtype
+
+        if "center_point_box" in attr:
+            assert (
+                attr["center_point_box"] == 0
+            ), "Only support center_point_box = 0 in onnx importer right now"
+
+        if iou_threshold is None:
+            iou_threshold = _expr.const(0.0, dtype="float32")
+        if score_threshold is None:
+            score_threshold = _expr.const(0.0, dtype="float32")
+
+        def conditionally_squeeze_scalar(x):
+            rank = len(infer_shape(x))
+            assert rank <= 1, "nms thresholds must be scalars"
+            if rank == 1:
+                return _op.squeeze(x, [0])
+            return x
+
+        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
+        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
+        score_threshold = conditionally_squeeze_scalar(score_threshold)
+
+        ## prepare utility constants
+        zero = _op.const(np.array([0]), dtype="int64")
+        one = _op.const(np.array([1]), dtype="int64")
+        two = _op.const(np.array([2]), dtype="int64")
+        three = _op.const(np.array([3]), dtype="int64")
+        three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
+        four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")
+
+        ## First loop: split by class and perform NMS
+        # Create Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), 
_ty.Any()), dtype=dtype)
+        boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), 
dtype=dtype)
+        max_output_boxes_per_class_var = _expr.var(
+            "max_output_boxes_per_class_var", shape=(), dtype="int64"
+        )
+        iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
+        score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        S = _expr.var("S", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), 
_ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 
1), dtype="int64")
+
+        def _first_cond(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            # Loop over classes, end when i == C
+            return _op.min(_op.less(i, C))
+
+        def _first_body(
+            i,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            onnx_out,
+            nms_size_out,
+        ):
+            # slice to get current class
+            begin = _op.concatenate([zero, i, zero], axis=0)
+            end = _op.concatenate([B, i + one, S], axis=0)
+            class_scores = _op.strided_slice(scores, begin, end, three_ones)
+            class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
+            # combine scores and boxes
+            data = _op.concatenate([class_scores, boxes], axis=-1)
+
+            # get valid counts
+            ct, data, indices = _op.vision.get_valid_counts(
+                data, score_threshold=score_threshold, id_index=-1, score_index=0
+            )
+            # get_valid_counts is used here for inference performance
+            # ONNX NMS doesn't have parameter top_k
+            top_k = -1
+            # ONNX doesn't have class id for nms input
+            score_index = 0
+            # perform nms on current class
+            nms_ret = _op.vision.non_max_suppression(
+                data=data,
+                valid_count=ct,
+                indices=indices,
+                max_output_size=max_output_boxes_per_class,
+                iou_threshold=iou_threshold,
+                force_suppress=True,
+                top_k=top_k,
+                coord_start=1,
+                score_index=score_index,
+                id_index=-1,
+                return_indices=True,
+                invalid_to_bottom=False,
+            )
+            # partially prepare ONNX output format by labeling batch_num, class_id
+            nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
+            batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
+            batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
+            batch_num = _op.expand_dims(batch_num, -1, 1)
+            class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
+            new_onnx_out = _op.concatenate(
+                [batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
+            )
+            new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
+            # store valid nms outputs for this class
+            nms_size = _op.cast(nms_ret[1], "int64")
+            nms_size = _op.expand_dims(nms_size, 1, 1)
+            return [
+                i + one,
+                scores,
+                boxes,
+                B,
+                C,
+                S,
+                max_output_boxes_per_class,
+                iou_threshold,
+                score_threshold,
+                _op.concatenate([onnx_out, new_onnx_out], axis=1),
+                _op.concatenate([nms_size_out, nms_size], axis=1),
+            ]
+
+        # create the first loop
+        first_loop = _loops.while_loop(
+            _first_cond,
+            [
+                i,
+                scores_var,
+                boxes_var,
+                B,
+                C,
+                S,
+                max_output_boxes_per_class_var,
+                iou_threshold_var,
+                score_threshold_var,
+                onnx_out,
+                nms_size_out,
+            ],
+            _first_body,
+        )
+
+        ## Second loop slices outputs of the first loop for valid boxes and
+        ##  concats in the order ONNX wants
+        # Second inner Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        j = _expr.var("j", shape=(1,), dtype="int64")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
+
+        def _inner_cond(i, j, C, onnx_out, nms_size, out):
+            # inner loop over number of classes
+            return _op.min(_op.less(j, C))
+
+        def _inner_body(i, j, C, onnx_out, nms_size, out):
+            # slice to get current batch and class for valid box indicator
+            start = _op.concatenate([i, j + one, zero], axis=0)
+            end = _op.concatenate([i + one, j + two, one], axis=0)
+            num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
+            # slice to get current batch, class, and valid outputs
+            start = _op.concatenate([i, j + one, zero, zero], axis=0)
+            end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
+            new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
+            return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)
+
+        inner_loop = _loops.while_loop(
+            _inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body
+        )
+
+        # Second Outer Loop Vars
+        i = _expr.var("i", shape=(1,), dtype="int64")
+        j = _expr.var("j", shape=(1,), dtype="int64")
+        B = _expr.var("B", shape=(1,), dtype="int64")
+        C = _expr.var("C", shape=(1,), dtype="int64")
+        # Outputs of first loop should be padded nms values shape (B, C, S, 3)
+        onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
+        # and sizes of valid outputs, shape (B, C, 1)
+        nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
+        out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")
+
+        def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
+            # Outer loop is over batch size
+            return _op.min(_op.less(i, B))
+
+        def _outer_body(i, B, C, onnx_out, nms_size_out, out):
+            # Outer loop just calls inner loop
+            init_count = _op.const(np.array([0]), dtype="int64")
+            inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
+            return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)
+
+        # Create the second loop
+        outer_loop = _loops.while_loop(
+            _outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
+        )
+
+        # Call the first loop, perform NMS
+        B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
+        init_count = _op.const(np.array([0]), dtype="int64")
+        init_onnx_out = _op.const([1], dtype="int64")
+        init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
+        init_nms_size_out = _op.const([1], dtype="int64")
+        init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
+        loop_vals = first_loop(
+            init_count,
+            scores,
+            boxes,
+            B,
+            C,
+            S,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            init_onnx_out,
+            init_nms_size_out,
+        )
+        onnx_output = _expr.TupleGetItem(loop_vals, 9)
+        nms_size_output = _expr.TupleGetItem(loop_vals, 10)
+
+        # Call the second loop, rework outputs into correct form
+        init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
+        init_out = _op.const(np.array([]).reshape([0, 3]).astype("int64"), 
dtype="int64")
+        loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, 
init_out)
+
+        return _expr.TupleGetItem(loop_vals, 5)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -2415,6 +2683,7 @@ def _get_convert_map(opset):
         # defs/vision
         "MaxRoiPool": MaxRoiPool.get_converter(opset),
         "RoiAlign": RoiAlign.get_converter(opset),
+        "NonMaxSuppression": NonMaxSuppression.get_converter(opset),
         # defs/reduction
         "ReduceMax": ReduceMax.get_converter(opset),
         "ReduceMin": ReduceMin.get_converter(opset),
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index e888eb4..10dc7b9 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -885,9 +885,11 @@ def wrap_compute_get_valid_counts(topi_compute):
     """wrap get_valid_counts topi compute"""
 
     def _compute_get_valid_counts(attrs, inputs, out_type):
-        score_threshold = get_const_float(attrs.score_threshold)
+        score_threshold = inputs[1]
         id_index = get_const_int(attrs.id_index)
         score_index = get_const_int(attrs.score_index)
+        if attrs.score_threshold is not None:
+            score_threshold = get_const_float(attrs.score_threshold)
         return topi_compute(inputs[0], score_threshold, id_index, score_index)
 
     return _compute_get_valid_counts
@@ -911,10 +913,12 @@ def wrap_compute_nms(topi_compute):
 
     def _compute_nms(attrs, inputs, out_type):
         max_output_size = inputs[3]
+        iou_threshold = inputs[4]
         if attrs.max_output_size is not None:
             max_output_size = attrs.max_output_size
+        if attrs.iou_threshold is not None:
+            iou_threshold = get_const_float(attrs.iou_threshold)
         return_indices = bool(get_const_int(attrs.return_indices))
-        iou_threshold = get_const_float(attrs.iou_threshold)
         force_suppress = bool(get_const_int(attrs.force_suppress))
         top_k = get_const_int(attrs.top_k)
         coord_start = get_const_int(attrs.coord_start)
diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py
index 4366609..0a3df40 100644
--- a/python/tvm/relay/op/vision/nms.py
+++ b/python/tvm/relay/op/vision/nms.py
@@ -48,6 +48,8 @@ def get_valid_counts(data, score_threshold, id_index=0, score_index=1):
     out_indices: relay.Expr
         Indices in input data
     """
+    if not isinstance(score_threshold, expr.Expr):
+        score_threshold = expr.const(score_threshold, "float32")
     return expr.TupleWrapper(
         _make.get_valid_counts(data, score_threshold, id_index, score_index), 3
     )
@@ -94,7 +96,7 @@ def non_max_suppression(
         Max number of output valid boxes for each instance.
         Return all valid boxes if the value of max_output_size is less than 0.
 
-    iou_threshold : float, optional
+    iou_threshold : float or relay.Expr, optional
         Non-maximum suppression threshold.
 
     force_suppress : bool, optional
@@ -126,8 +128,10 @@ def non_max_suppression(
         If return_indices is True, return relay.Tuple of two 2-D tensors, with
         shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively.
     """
-    if isinstance(max_output_size, int):
+    if not isinstance(max_output_size, expr.Expr):
         max_output_size = expr.const(max_output_size, "int32")
+    if not isinstance(iou_threshold, expr.Expr):
+        iou_threshold = expr.const(iou_threshold, "float32")
     out = _make.non_max_suppression(
         data,
         valid_count,
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index d51eb5c..d0915d9 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -52,71 +52,191 @@ def atomic_add(x, y):
     return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)
 
 
-def rearrange_indices_out_ir(data, out, valid_box_count):
+def rearrange_indices_out_ir(data, output, valid_box_count):
     """Hybrid routine to rearrange nms output to
     move all valid entries to top.
 
     Parameters
     ----------
     data : tvm.te.Tensor or numpy NDArray
+        NMS output. 3-D tensor with shape
+        [batch_size, num_anchors, 6] or
+        [batch_size, num_anchors, 5], or 2-D
         tensor with shape [batch_size, num_anchors].
 
     Returns
     -------
-    stmt : Stmt
-        The result IR statement.
+    output : tvm.te.Tensor or numpy NDArray
+        2-D tensor with shape [batch_size, num_anchors].
+
+    valid_box_count : tvm.te.Tensor or numpy NDArray
+        Tensor with shape [batch_size, 1], indicates
+        the valid number of boxes.
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
 
     ib = tvm.tir.ir_builder.create()
+
     data = ib.buffer_ptr(data)
-    out = ib.buffer_ptr(out)
     valid_box_count = ib.buffer_ptr(valid_box_count)
+    output = ib.buffer_ptr(output)
+
+    with ib.new_scope():
+        i = te.thread_axis("blockIdx.x")
+        ib.scope_attr(i, "thread_extent", batch_size)
+        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
+        valid_idx[0] = 0
+        with ib.for_range(0, num_anchors, name="j") as j:
+            with ib.if_scope(data[i, j] >= 0):
+                with ib.if_scope(data[i, j] > num_anchors):
+                    output[i, valid_idx[0]] = 0
+                    valid_idx[0] = valid_idx[0] + 1
+                with ib.else_scope():
+                    output[i, valid_idx[0]] = data[i, j]
+                    valid_idx[0] = valid_idx[0] + 1
+            with ib.else_scope():
+                with ib.if_scope(data[i, j] < -num_anchors):
+                    output[i, valid_idx[0]] = 0
+                    valid_idx[0] = valid_idx[0] + 1
+            with ib.if_scope(j >= valid_idx[0]):
+                output[i, j] = -1
+        valid_box_count[i, 0] = valid_idx[0]
 
-    one_count = tvm.tir.const(1, dtype="int32")
-    atomic_add_return = ib.allocate(
-        valid_box_count.dtype, (batch_size,), name="atomic_add_return", 
scope="local"
-    )
+    return ib.get()
+
+
+def get_valid_boxes_ir(data, valid_boxes, score_threshold, id_index, score_index):
+    """Low level IR to identify bounding boxes given a score threshold.
+
+    Parameters
+    ----------
+    data : Buffer
+        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
+
+    score_threshold : Buffer or float32
+        Lower limit of score for valid bounding boxes.
+
+    id_index : optional, int
+        index of the class categories, -1 to disable.
+
+    score_index: optional, int
+        Index of the scores/confidence of boxes.
+
+    Returns
+    -------
+    valid_boxes: Buffer
+        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
+
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+
+    valid_boxes = ib.buffer_ptr(valid_boxes)
+    if isinstance(score_threshold, float):
+        score_threshold = tvm.tir.FloatImm("float32", score_threshold)
+    id_index = tvm.tir.IntImm("int32", id_index)
+    score_index = tvm.tir.IntImm("int32", score_index)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    nthread_tx = max_threads
-    tx = te.thread_axis("threadIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    len_inner_for = (batch_size * num_anchors) // nthread_tx + 2
-
-    idxd = tvm.tir.indexdiv
-    idxm = tvm.tir.indexmod
-
-    with ib.for_range(0, len_inner_for, name="i") as i:
-        idx = tx * len_inner_for + i
-        batch_idx = idxd(idx, num_anchors)
-        with ib.if_scope(idx < batch_size):
-            valid_box_count[idx] = 0
-        with ib.if_scope(idx < batch_size * num_anchors):
-            with ib.if_scope(data[idx] >= 0):
-                atomic_add_return[batch_idx] = atomic_add(
-                    tvm.tir.call_intrin("handle", "tir.address_of", 
valid_box_count[batch_idx]),
-                    one_count,
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = num_anchors // max_threads + 1
+        nthread_by = batch_size
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            score = data[(i * num_anchors + j) * elem_length + score_index]
+            with ib.if_scope(
+                tvm.tir.all(
+                    score > score_threshold,
+                    tvm.tir.any(
+                        id_index < 0, data[(i * num_anchors + j) * elem_length 
+ id_index] >= 0
+                    ),
                 )
-                out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = data[idx]
-            with ib.if_scope(tvm.tir.any(data[idx] > num_anchors, data[idx] < 
-num_anchors)):
-                atomic_add_return[batch_idx] = atomic_add(
-                    tvm.tir.call_intrin("handle", "tir.address_of", 
valid_box_count[batch_idx]),
-                    one_count,
-                )
-                out[batch_idx * num_anchors + atomic_add_return[batch_idx]] = 0
+            ):
+                valid_boxes[i * num_anchors + j] = 1
+            with ib.else_scope():
+                valid_boxes[i * num_anchors + j] = 0
+    return ib.get()
+
+
+def get_valid_indices_ir(valid_boxes, valid_count, valid_indices):
+    """Low level IR to get the ouput indices of valid boxes
+    and the count of valid boxes
+
+    Parameters
+    ----------
+    valid_boxes: Buffer
+        2D Buffer indicating valid boxes with shape [batch_size, num_anchors].
+
+    Returns
+    -------
+    valid_count: Buffer
+        1D Buffer of number of valid boxes per batch [batch_size].
+
+    valid_indices: Buffer
+        2D Buffer indicating output sorted indices of valid boxes [batch_size, num_anchors].
+    """
+    batch_size = valid_boxes.shape[0]
+    num_anchors = valid_boxes.shape[1]
 
-            with ib.if_scope(idxm(idx, num_anchors) >= 
valid_box_count[batch_idx]):
-                out[idx] = -1
+    ib = tvm.tir.ir_builder.create()
+
+    valid_boxes = ib.buffer_ptr(valid_boxes)
+
+    valid_count = ib.buffer_ptr(valid_count)
+    valid_indices = ib.buffer_ptr(valid_indices)
 
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = batch_size // max_threads + 1
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * max_threads + tx
+        # TODO(mbrookhart): Parallelize the sum and cumsum here
+        current_index = ib.allocate("int32", (1,), name="current_index", 
scope="local")
+        with ib.if_scope(tid < batch_size):
+            current_index[0] = 0
+            valid_count[tid] = 0
+            with ib.for_range(0, num_anchors) as j:
+                idx = tid * num_anchors + j
+                valid_count[tid] = valid_count[tid] + valid_boxes[idx]
+                with ib.if_scope(valid_boxes[idx] == 1):
+                    valid_indices[idx] = current_index[0]
+                    current_index[0] = current_index[0] + 1
+                with ib.else_scope():
+                    valid_indices[idx] = -1
     return ib.get()
 
 
-def get_valid_counts_ir(
-    data, valid_count, out, out_indices, score_threshold, id_index, score_index
-):
+def get_valid_counts_ir(data, valid_indices, out, out_indices):
     """Low level IR to get valid count of bounding boxes
     given a score threshold. Also prepares to move valid boxes to the
     top of input data.
@@ -126,25 +246,16 @@ def get_valid_counts_ir(
     data : Buffer
         Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
 
-    valid_count : Buffer
-        1D buffer for valid number of boxes with shape [batch_size, ].
-
-    flag : Buffer
+    valid_indices: Buffer
         2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
 
-    score_threshold : float32
-        Lower limit of score for valid bounding boxes.
-
-    id_index : optional, int
-        index of the class categories, -1 to disable.
-
-    score_index: optional, int
-        Index of the scores/confidence of boxes.
-
     Returns
     -------
-    stmt : Stmt
-        The result IR statement.
+    out : Buffer
+        Sorted valid boxes
+
+    out_indices : Buffer
+        Indices of valid boxes in original data
     """
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
@@ -154,50 +265,51 @@ def get_valid_counts_ir(
 
     data = ib.buffer_ptr(data)
 
-    valid_count = ib.buffer_ptr(valid_count)
+    valid_indices = ib.buffer_ptr(valid_indices)
     out = ib.buffer_ptr(out)
     out_indices = ib.buffer_ptr(out_indices)
-    atomic_add_return = ib.allocate(
-        valid_count.dtype, (1,), name="atomic_add_return", scope="local"
-    )
-    one_count = tvm.tir.const(1, dtype=valid_count.dtype)
     one = tvm.tir.const(1, dtype=out.dtype)
-    score_threshold = tvm.tir.FloatImm("float32", score_threshold)
-    id_index = tvm.tir.IntImm("int32", id_index)
-    score_index = tvm.tir.IntImm("int32", score_index)
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
     nthread_tx = max_threads
-    nthread_bx = batch_size * num_anchors // max_threads + 1
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    tid = bx * max_threads + tx
-    idxd = tvm.tir.indexdiv
-
-    # initialize valid_count
-    with ib.if_scope(tid < batch_size):
-        valid_count[tid] = 0
-    with ib.if_scope(tid < batch_size * num_anchors):
-        i = idxd(tid, num_anchors)
-        with ib.if_scope(
-            tvm.tir.all(
-                data[tid * elem_length + score_index] > score_threshold,
-                tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] 
>= 0),
-            )
-        ):
-            atomic_add_return[0] = atomic_add(
-                tvm.tir.call_intrin("handle", "tir.address_of", 
valid_count[i]), one_count
-            )
-            with ib.for_range(0, elem_length) as k:
-                out[tid * elem_length + k] = data[tid * elem_length + k]
-                out_indices[tid + k] = tid + k
-        with ib.else_scope():
-            with ib.for_range(0, elem_length) as k:
-                out[tid * elem_length + k] = -one
-                out_indices[tid + k] = -one_count
-
+    nthread_bx = num_anchors // max_threads + 1
+    nthread_by = batch_size
+    nthread_bz = elem_length
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            k = bz
+            out[(i * num_anchors + j) * elem_length + k] = -one
+            out_indices[i * num_anchors + j] = -1
+    with ib.new_scope():
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < num_anchors):
+            i = by
+            j = tid
+            k = bz
+            with ib.if_scope(valid_indices[i, tid] >= 0):
+                out[(i * num_anchors + valid_indices[i, tid]) * elem_length + k] = data[
+                    (i * num_anchors + j) * elem_length + k
+                ]
+                out_indices[i * num_anchors + valid_indices[i, tid]] = j
     return ib.get()
 
 
@@ -210,7 +322,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     data : tvm.te.Tensor
         Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].
 
-    score_threshold : optional, float
+    score_threshold : optional, tvm.te.Tensor or float
         Lower limit of score for valid bounding boxes.
 
     id_index : optional, int
@@ -230,23 +342,51 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", 
data_alignment=8)
+    valid_boxes_buf = tvm.tir.decl_buffer(
+        (batch_size, num_anchors), "int32", "valid_boxes_buf", data_alignment=8
+    )
+    valid_boxes = te.extern(
+        [(batch_size, num_anchors)],
+        [data],
+        lambda ins, outs: get_valid_boxes_ir(
+            ins[0], outs[0], score_threshold, id_index, score_index
+        ),
+        dtype=["int32"],
+        in_buffers=[data_buf],
+        out_buffers=[valid_boxes_buf],
+        name="get_valid_boxes",
+        tag="get_valid_boxes_gpu",
+    )
+
+    valid_indices_buf = tvm.tir.decl_buffer(
+        (batch_size, num_anchors), "int32", "valid_indices_buf", 
data_alignment=8
+    )
     valid_count_buf = tvm.tir.decl_buffer(
         (batch_size,), "int32", "valid_count_buf", data_alignment=8
     )
+    valid_count, valid_indices = te.extern(
+        [(batch_size,), (batch_size, num_anchors)],
+        [valid_boxes],
+        lambda ins, outs: get_valid_indices_ir(ins[0], outs[0], outs[1]),
+        dtype=["int32"],
+        in_buffers=[valid_boxes_buf],
+        out_buffers=[valid_count_buf, valid_indices_buf],
+        name="get_valid_indices",
+        tag="get_valid_indices_gpu",
+    )
+
     out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", 
data_alignment=8)
     out_indices_buf = tvm.tir.decl_buffer(
         (batch_size, num_anchors), "int32", "out_buf", data_alignment=8
     )
 
-    valid_count, out, out_indices = te.extern(
-        [(batch_size,), data.shape, (batch_size, num_anchors)],
-        [data],
-        lambda ins, outs: get_valid_counts_ir(
-            ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index
-        ),
+    out, out_indices = te.extern(
+        [data.shape, (batch_size, num_anchors)],
+        [data, valid_indices],
+        lambda ins, outs: get_valid_counts_ir(ins[0], ins[1], outs[0], outs[1]),
         dtype=["int32", data.dtype],
-        in_buffers=[data_buf],
-        out_buffers=[valid_count_buf, out_buf, out_indices_buf],
+        in_buffers=[data_buf, valid_indices_buf],
+        out_buffers=[out_buf, out_indices_buf],
         name="get_valid_counts",
         tag="get_valid_counts_gpu",
     )
@@ -277,12 +417,19 @@ def nms_ir(
     data : Buffer
         Buffer of output boxes with class and score.
 
-    sort_index : Buffer
+    sorted_index : Buffer
         Buffer of output box indexes sorted by score.
 
     valid_count : Buffer
         Buffer of number of valid output boxes.
 
+    indices : Buffer
+        indices in original tensor, with shape [batch_size, num_anchors],
+        represents the index of box in original data. It could be the third
+        output out_indices of get_valid_counts. The values in the second
+        dimension are like the output of arange(num_anchors) if get_valid_counts
+        is not used before non_max_suppression.
+
     out : Buffer
         Output buffer.
 
@@ -308,33 +455,50 @@ def nms_ir(
     score_index : optional, int
         Index of the scores/confidence of boxes.
 
+    return_indices : boolean
+        Whether to return box indices in input data.
+
     Returns
     -------
     stmt : Stmt
         The result IR statement.
     """
 
-    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
-        """Calculate overlap of two boxes."""
-        w = tvm.te.max(
-            0.0,
-            tvm.te.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
-            - tvm.te.max(out_tensor[box_a_idx], out_tensor[box_b_idx]),
+    def get_boundaries(output, box_idx):
+        l = tvm.te.min(
+            output[box_idx],
+            output[box_idx + 2],
+        )
+        t = tvm.te.min(
+            output[box_idx + 1],
+            output[box_idx + 3],
         )
-        h = tvm.te.max(
-            0.0,
-            tvm.te.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
-            - tvm.te.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]),
+        r = tvm.te.max(
+            output[box_idx],
+            output[box_idx + 2],
         )
-        i = w * h
-        u = (
-            (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx])
-            * (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1])
-            + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx])
-            * (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1])
-            - i
+        b = tvm.te.max(
+            output[box_idx + 1],
+            output[box_idx + 3],
         )
-        return tvm.tir.Select(u <= 0.0, 0.0, i / u)
+        return l, t, r, b
+
+    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
+        """Calculate overlap of two boxes."""
+        a_l, a_t, a_r, a_b = get_boundaries(out_tensor, box_a_idx)
+        b_l, b_t, b_r, b_b = get_boundaries(out_tensor, box_b_idx)
+
+        # Overlapping width and height
+        w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l))
+        h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t))
+
+        # Overlapping area
+        area = h * w
+
+        # total area of the figure formed by box a and box b
+        # except for overlapping area
+        u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
+        return tvm.tir.Select(u <= 0.0, 0.0, area / u)
 
     batch_size = data.shape[0]
     num_anchors = data.shape[1]
@@ -345,60 +509,64 @@ def nms_ir(
     data = ib.buffer_ptr(data)
     sorted_index = ib.buffer_ptr(sorted_index)
     valid_count = ib.buffer_ptr(valid_count)
+    indices = ib.buffer_ptr(indices)
     out = ib.buffer_ptr(out)
     box_indices = ib.buffer_ptr(box_indices)
-    indices = ib.buffer_ptr(indices)
-    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", 
scope="local")
-
-    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    nthread_tx = max_threads
-    nthread_bx = num_anchors // max_threads + 1
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
-    j = bx * max_threads + tx
 
-    iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
+    if isinstance(iou_threshold, float):
+        iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
     top_k = tvm.tir.IntImm("int32", top_k)
     coord_start = tvm.tir.IntImm("int32", coord_start)
     id_index = tvm.tir.IntImm("int32", id_index)
     score_index = tvm.tir.IntImm("int32", score_index)
     force_suppress = tvm.tir.IntImm("int32", 1 if force_suppress else 0)
 
-    with ib.for_range(0, batch_size, for_type="unroll") as i:
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+
+    with ib.new_scope():
+        nthread_by = batch_size
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        i = by
         base_idx = i * num_anchors * box_data_length
         with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Reorder output
             nkeep = if_then_else(
                 tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, 
valid_count[i]
             )
-            with ib.if_scope(j < nkeep):
+            with ib.for_range(0, nkeep) as j:
                 with ib.for_range(0, box_data_length) as k:
                     out[(base_idx + j * box_data_length + k)] = data[
                         (base_idx + sorted_index[i * num_anchors + j] * box_data_length + k)
                     ]
                 box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
             with ib.if_scope(tvm.tir.all(top_k > 0, top_k < valid_count[i])):
-                with ib.if_scope(j < valid_count[i] - nkeep):
+                with ib.for_range(0, valid_count[i] - nkeep) as j:
                     with ib.for_range(0, box_data_length) as k:
                         out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
                     box_indices[i * num_anchors + (j + nkeep)] = -1
+    with ib.new_scope():
+        nthread_by = batch_size
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        i = by
+        base_idx = i * num_anchors * box_data_length
+        with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
             # Apply nms
-            with ib.for_range(0, valid_count[i]) as k:
-                offset_k = k * box_data_length
-                with ib.if_scope(
-                    tvm.tir.all(
-                        out[base_idx + offset_k + score_index] > 0,
-                        tvm.tir.any(id_index < 0, out[base_idx + offset_k + 
id_index] >= 0),
-                    )
-                ):
-                    with ib.if_scope(j < valid_count[i]):
+            with ib.for_range(0, valid_count[i]) as j:
+                with ib.for_range(0, j) as k:
+                    offset_k = k * box_data_length
+                    with ib.if_scope(
+                        tvm.tir.all(
+                            out[base_idx + offset_k + score_index] > 0,
+                            tvm.tir.any(id_index < 0, out[base_idx + offset_k 
+ id_index] >= 0),
+                        )
+                    ):
                         offset_j = j * box_data_length
                         with ib.if_scope(
                             tvm.tir.all(
                                 j > k,
-                                out[base_idx + offset_j + score_index] > 0,
+                                out[base_idx + offset_k + score_index] > 0,
                                 tvm.tir.any(id_index < 0, out[base_idx + 
offset_j + id_index] >= 0),
                                 tvm.tir.any(
                                     force_suppress > 0,
@@ -418,21 +586,47 @@ def nms_ir(
                                 with ib.if_scope(id_index >= 0):
                                     out[base_idx + offset_j + id_index] = -1.0
                                 box_indices[i * num_anchors + j] = -1
+    with ib.new_scope():
+        nthread_tx = max_threads
+        nthread_bx = num_anchors // max_threads + 1
+        nthread_by = batch_size
+        nthread_bz = box_data_length
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        bz = te.thread_axis("blockIdx.z")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        ib.scope_attr(by, "thread_extent", nthread_by)
+        ib.scope_attr(bz, "thread_extent", nthread_bz)
+        tid = bx * max_threads + tx
+        i = by
+        j = tid
+        k = bz
+        base_idx = i * num_anchors * box_data_length
+        with ib.if_scope(tvm.tir.all(iou_threshold > 0, valid_count[i] > 0)):
+            pass
         with ib.else_scope():
             with ib.if_scope(j < valid_count[i]):
                 offset_j = j * box_data_length
-                with ib.for_range(0, box_data_length) as k:
-                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
+                out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
                 box_indices[i * num_anchors + j] = j
+
+    with ib.new_scope():
+        num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", 
scope="local")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", batch_size)
+        i = bx
+        base_idx = i * num_anchors * box_data_length
         # Set invalid entry to be -1
-        with ib.if_scope(j < num_anchors - valid_count[i]):
+        with ib.for_range(0, num_anchors - valid_count[i]) as j:
             with ib.for_range(0, box_data_length) as k:
                 out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
             box_indices[i * num_anchors + j + valid_count[i]] = -1
         # Only return max_output_size number of valid boxes
         num_valid_boxes[0] = 0
         with ib.if_scope(max_output_size > 0):
-            with ib.if_scope(j < valid_count[i]):
+            with ib.for_range(0, valid_count[i]) as j:
                 offset_j = j * box_data_length
                 with ib.if_scope(out[base_idx + offset_j] >= 0):
                     with ib.if_scope(num_valid_boxes[0] == max_output_size):
@@ -442,11 +636,20 @@ def nms_ir(
                     with ib.else_scope():
                         num_valid_boxes[0] += 1
 
-        if return_indices:
-            with ib.if_scope(j < valid_count[i]):
-                box_idx = box_indices[i * num_anchors + j]
-                with ib.if_scope(box_idx >= 0):
-                    box_indices[i * num_anchors + j] = indices[i * num_anchors + box_idx]
+    if return_indices:
+        with ib.new_scope():
+            nthread_tx = max_threads
+            nthread_bx = batch_size // max_threads + 1
+            tx = te.thread_axis("threadIdx.x")
+            bx = te.thread_axis("blockIdx.x")
+            ib.scope_attr(tx, "thread_extent", nthread_tx)
+            ib.scope_attr(bx, "thread_extent", nthread_bx)
+            i = bx * max_threads + tx
+            with ib.if_scope(i < batch_size):
+                with ib.for_range(0, valid_count[i]) as j:
+                    idx = box_indices[i * num_anchors + j]
+                    with ib.if_scope(idx >= 0):
+                        box_indices[i * num_anchors + j] = indices[i * num_anchors + idx]
 
     return ib.get()
 
@@ -486,11 +689,11 @@ def non_max_suppression(
         second dimension are like the output of arange(num_anchors)
         if get_valid_counts is not used before non_max_suppression.
 
-    max_output_size : optional, int
+    max_output_size : optional, tvm.te.Tensor or int
         Max number of output valid boxes for each instance.
         By default all valid boxes are returned.
 
-    iou_threshold : optional, float
+    iou_threshold : optional, tvm.te.Tensor or float
         Non-maximum suppression threshold.
 
     force_suppress : optional, boolean
@@ -570,6 +773,8 @@ def non_max_suppression(
         sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", 
data_alignment=8
     )
 
+    indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
+
     data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
     indices_buf = tvm.tir.decl_buffer(indices.shape, indices.dtype, "indices_buf", data_alignment=8)
 
@@ -597,19 +802,19 @@ def non_max_suppression(
         name="nms",
         tag="nms",
     )
-
     if return_indices:
-        out_buf = tvm.tir.decl_buffer(
-            box_indices.shape, box_indices.dtype, "out_buf", data_alignment=8
-        )
+        out_shape = box_indices.shape
+        valid_box_count_shape = [box_indices.shape[0], 1]
+        valid_box_count = tvm.tir.decl_buffer(valid_box_count_shape, "int32", 
"valid_box_count")
+        output = tvm.tir.decl_buffer(box_indices.shape, "int32", "output")
         return te.extern(
-            [box_indices.shape, (batch_size, 1)],
+            [out_shape, valid_box_count_shape],
             [box_indices],
             lambda ins, outs: rearrange_indices_out_ir(ins[0], outs[0], outs[1]),
-            dtype=[box_indices.dtype, valid_count.dtype],
-            in_buffers=[out_buf],
-            name="rearrange_indices_out",
-            tag="rearrange_indices_out",
+            dtype="int32",
+            out_buffers=[output, valid_box_count],
+            name="rearrange_indices_out_gpu",
+            tag="rearrange_indices_out_gpu",
         )
 
     return out
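
As a sanity check on the rewritten calculate_overlap above: get_boundaries takes
the min/max of each corner pair, so the IoU is insensitive to the ordering of the
box coordinates. The same computation in plain NumPy (a sketch, not code from this
commit):

    import numpy as np

    def iou(box_a, box_b):
        # normalize corners, mirroring get_boundaries
        a_l, a_r = min(box_a[0], box_a[2]), max(box_a[0], box_a[2])
        a_t, a_b = min(box_a[1], box_a[3]), max(box_a[1], box_a[3])
        b_l, b_r = min(box_b[0], box_b[2]), max(box_b[0], box_b[2])
        b_t, b_b = min(box_b[1], box_b[3]), max(box_b[1], box_b[3])
        # overlapping width and height, clamped at zero
        w = max(0.0, min(a_r, b_r) - max(a_l, b_l))
        h = max(0.0, min(a_b, b_b) - max(a_t, b_t))
        area = w * h
        # union is the sum of both areas minus the overlap
        u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
        return 0.0 if u <= 0.0 else area / u

    assert np.isclose(iou([0, 0, 2, 2], [1, 1, 3, 3]), 1.0 / 7.0)
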
diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py
index 0094ef1..329f0fb 100644
--- a/python/tvm/topi/cuda/sort.py
+++ b/python/tvm/topi/cuda/sort.py
@@ -104,9 +104,9 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
     nthread_bx = shape[axis] // max_threads + 1
 
     tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("vthread")
+    bx = te.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "virtual_thread", nthread_bx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * nthread_tx + tx
     temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", 
scope="local")
     if indices_out is not None:
@@ -202,9 +202,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     nthread_tx = max_threads
     nthread_bx = size // max_threads + 1
     tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("vthread")
+    bx = te.thread_axis("blockIdx.x")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "virtual_thread", nthread_bx)
+    ib.scope_attr(bx, "thread_extent", nthread_bx)
     tid = bx * nthread_tx + tx
     temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
     temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
index b076fde..035d19f 100644
--- a/python/tvm/topi/vision/nms.py
+++ b/python/tvm/topi/vision/nms.py
@@ -133,7 +133,7 @@ def hybrid_get_valid_counts(
         Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
         or [batch_size, num_anchors, 5].
 
-    score_threshold : tvm.tir.const
+    score_threshold : tvm.te.Tensor
         Lower limit of score for valid bounding boxes.
 
     id_index : tvm.tir.const
@@ -213,12 +213,13 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
     out_indices: tvm.te.Tensor or numpy NDArray
         Related index in input data.
     """
-    score_threshold_const = tvm.tir.const(score_threshold, data.dtype)
+    if isinstance(score_threshold, float):
+        score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype)
     id_index_const = tvm.tir.const(id_index, "int32")
     score_index_const = tvm.tir.const(score_index, "int32")
     return hybrid_get_valid_counts(
         data,
-        score_threshold_const,
+        score_threshold,
         id_index_const,
         score_index_const,
         tvm.tir.const(1, data.dtype),
@@ -281,7 +282,7 @@ def hybrid_nms(
         Max number of output valid boxes for each instance.
         Return all valid boxes if max_output_size < 0.
 
-    iou_threshold : tvm.tir.const
+    iou_threshold : tvm.te.Tensor
         Overlapping(IoU) threshold to suppress object with smaller score.
 
     force_suppress : tvm.tir.const
@@ -494,7 +495,7 @@ def non_max_suppression(
         Max number of output valid boxes for each instance.
         Return all valid boxes if the value of max_output_size is less than 0.
 
-    iou_threshold : optional, float
+    iou_threshold : optional, float or tvm.te.Tensor
         Non-maximum suppression threshold.
 
     force_suppress : optional, boolean
@@ -554,6 +555,8 @@ def non_max_suppression(
     num_anchors = data.shape[1]
     if isinstance(max_output_size, int):
         max_output_size = tvm.tir.const(max_output_size, dtype="int32")
+    if isinstance(iou_threshold, float):
+        iou_threshold = tvm.tir.const(iou_threshold, dtype=data.dtype)
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
     score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis])
@@ -567,7 +570,7 @@ def non_max_suppression(
         batch_size,
         num_anchors,
         max_output_size,
-        tvm.tir.const(iou_threshold, dtype=data.dtype),
+        iou_threshold,
         tvm.tir.const(force_suppress, dtype="bool"),
         tvm.tir.const(top_k, dtype="int32"),
         tvm.tir.const(coord_start, dtype="int32"),
@@ -577,6 +580,7 @@ def non_max_suppression(
         zero=tvm.tir.const(0, dtype=data.dtype),
         one=tvm.tir.const(1, dtype=data.dtype),
     )
+
     if return_indices:
         return hybrid_rearrange_indices_out(
             box_indices,
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index f652644..bed2510 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1070,6 +1070,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, 
const TargetsMap& targe
 
   pass_seqs.push_back(transform::FuseOps());
   pass_seqs.push_back(transform::ToANormalForm());
+  pass_seqs.push_back(transform::InferType());
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index f21d096..8e9cc62 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -111,6 +111,15 @@ class LambdaLifter : public ExprMutator {
       }
       captured_vars.push_back(var);
     }
+
+    Array<Var> typed_captured_vars;
+    Map<Var, Expr> rebinding_map;
+    for (auto free_var : captured_vars) {
+      auto var = Var(free_var->name_hint(), free_var->checked_type());
+      typed_captured_vars.push_back(var);
+      rebinding_map.Set(free_var, var);
+    }
+
     if (recursive) {
       if (!captured_vars.empty()) {
         Array<Expr> fvs;
@@ -122,6 +131,7 @@ class LambdaLifter : public ExprMutator {
         lambda_map_.emplace(letrec_.back(), global);
       }
     }
+
     auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
 
     // When performing this optimization there are two cases.
@@ -150,7 +160,25 @@ class LambdaLifter : public ExprMutator {
     if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
       lifted_func = Function(body->params, body->body, body->ret_type, 
body->type_params);
     } else {
-      lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
+      // When a closure is locally bound in a program, we have its full type information
+      // available to us.
+      //
+      // If we lift the closure out of its bound context it may have free variables which
+      // do not have type annotations.
+      //
+      // In this case we first type check the program assigning a type to all sub-expressions.
+      //
+      // We then change the un-annotated free variables into annotated free variables, use
+      // bind to go from unannotated free variables -> annotated free variables and then
+      // construct the "closure" function with fully annotated arguments, no longer relying
+      // on type inference.
+      auto before = Downcast<Function>(body)->params.size();
+      auto rebound_body = Function(func->params, Bind(body->body, 
rebinding_map), func->ret_type,
+                                   func->type_params, func->attrs, func->span);
+      auto after = Downcast<Function>(rebound_body)->params.size();
+      CHECK_EQ(before, after);
+      lifted_func =
+          Function(typed_captured_vars, rebound_body, func->func_type_annotation(), free_type_vars);
       lifted_func = MarkClosure(lifted_func);
     }
 
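
The comment block above is the heart of the fix: every captured variable is
cloned with an explicit annotation taken from its checked_type(), and Bind
rewrites the closure body to use the annotated clones, so the lifted function
no longer depends on type inference. A small Python-level repro of the closure
shape this protects, mirroring the updated tests further down (a sketch, not
code from this commit):

    # Sketch: y is free inside `inner`, so lambda lifting must materialize it
    # as a typed parameter of the lifted function. InferType has to run first
    # so that checked_type() is populated on the captured variable.
    import tvm
    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(2,), dtype="float32")
    y = relay.var("y", shape=(2,), dtype="float32")
    inner = relay.Function([x], x + y)   # y is captured, not a parameter
    outer = relay.Function([y], inner)

    mod = tvm.IRModule()
    mod["main"] = outer
    mod = transform.InferType()(mod)
    mod = transform.LambdaLift()(mod)    # the lifted closure has typed params
    print(mod)
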
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 4173d57..34aaf46 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -44,21 +44,29 @@ template <typename AttrType>
 bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
   // types: [data, result]
-  ICHECK_EQ(types.size(), 2);
+  ICHECK_EQ(types.size(), 2) << "the arity of concatenate is 2, not " << types.size();
   /* If we receive a tuple we can continue, if we receive
    * anything but an incomplete type we should signal an
    * error.
    */
   const auto* tensor_tuple = types[0].as<TupleTypeNode>();
   if (tensor_tuple == nullptr) {
-    throw Error(
-        ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found "
-                       << PrettyPrint(types[0]));
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "concatenate requires a tuple of tensors as the first argument, found "
+        << PrettyPrint(types[0]));
+    return false;
   } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
     return false;
   }
 
   const auto* param = attrs.as<AttrType>();
+  if (param == nullptr) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "the call attributes are not defined");
+    return false;
+  }
+
   if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
     return false;
   }
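
Beyond the new null check on the call attributes, the error path above now
reports through the DiagnosticContext instead of throwing a bare Error, so
type-checking failures carry source spans and flow through the standard
diagnostic machinery. A hedged illustration of how such a failure surfaces
from Python (the mismatched shapes below trip ConcatenateRel during type
inference; the exact exception class may vary by TVM version):

    # Sketch: force ConcatenateRel to reject its inputs and observe the error
    # raised by InferType. The shapes disagree off the concatenation axis.
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2), dtype="float32")
    y = relay.var("y", shape=(3, 3), dtype="float32")
    f = relay.Function([x, y], relay.concatenate([x, y], axis=0))
    mod = tvm.IRModule.from_expr(f)
    try:
        mod = relay.transform.InferType()(mod)
    except tvm.TVMError as err:  # diagnostic errors derive from TVMError
        print(err)
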
diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc
index 76fdf28..9316fec 100644
--- a/src/relay/op/vision/nms.cc
+++ b/src/relay/op/vision/nms.cc
@@ -31,8 +31,9 @@ TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs);
 
 bool GetValidCountRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                       const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 2);
+  ICHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
   const auto& dshape = data->shape;
   ICHECK_EQ(dshape.size(), 3) << "Input data should be 3-D.";
 
@@ -44,17 +45,16 @@ bool GetValidCountRel(const Array<Type>& types, int num_inputs, const Attrs& att
   fields.push_back(TensorType(oshape_indices, DataType::Int(32)));
 
   // assign output type
-  reporter->Assign(types[1], TupleType(Array<Type>(fields)));
+  reporter->Assign(types[2], TupleType(Array<Type>(fields)));
   return true;
 }
 
-Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) {
+Expr MakeGetValidCounts(Expr data, Expr score_threshold, int id_index, int score_index) {
   auto attrs = make_object<GetValidCountsAttrs>();
-  attrs->score_threshold = score_threshold;
   attrs->id_index = id_index;
   attrs->score_index = score_index;
   static const Op& op = Op::Get("vision.get_valid_counts");
-  return Call(op, {data}, Attrs(attrs), {});
+  return Call(op, {data, score_threshold}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts);
@@ -64,8 +64,9 @@ RELAY_REGISTER_OP("vision.get_valid_counts")
 a score threshold. Also moves valid boxes to the top of
 input data.
 )doc" TVM_ADD_FILELINE)
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("data", "Tensor", "Input data.")
+    .add_argument("score_threshold", "Tensor", "Minimum score.")
     .set_support_level(5)
     .add_type_rel("GetValidCount", GetValidCountRel);
 
@@ -73,9 +74,11 @@ TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs);
 
 bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
             const TypeReporter& reporter) {
-  ICHECK_EQ(types.size(), 5);
+  ICHECK_EQ(types.size(), 6);
   const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
   const auto* valid_count = types[1].as<TensorTypeNode>();
+  if (valid_count == nullptr) return false;
   const NonMaximumSuppressionAttrs* param = attrs.as<NonMaximumSuppressionAttrs>();
   const auto& dshape = data->shape;
   const auto& vshape = valid_count->shape;
@@ -90,18 +93,17 @@ bool NMSRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     fields.push_back(TensorType(oshape, DataType::Int(32)));
     std::vector<IndexExpr> countshape({dshape[0], 1});
     fields.push_back(TensorType(countshape, DataType::Int(32)));
-    reporter->Assign(types[4], TupleType(Array<Type>(fields)));
+    reporter->Assign(types[5], TupleType(Array<Type>(fields)));
   } else {
-    reporter->Assign(types[4], TensorType(dshape, data->dtype));
+    reporter->Assign(types[5], TensorType(dshape, data->dtype));
   }
   return true;
 }
 
-Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, double iou_threshold,
+Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, Expr iou_threshold,
              bool force_suppress, int top_k, int coord_start, int score_index, int id_index,
              bool return_indices, bool invalid_to_bottom) {
   auto attrs = make_object<NonMaximumSuppressionAttrs>();
-  attrs->iou_threshold = iou_threshold;
   attrs->force_suppress = force_suppress;
   attrs->top_k = top_k;
   attrs->coord_start = coord_start;
@@ -110,7 +112,7 @@ Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, do
   attrs->return_indices = return_indices;
   attrs->invalid_to_bottom = invalid_to_bottom;
   static const Op& op = Op::Get("vision.non_max_suppression");
-  return Call(op, {data, valid_count, indices, max_output_size}, Attrs(attrs), {});
+  return Call(op, {data, valid_count, indices, max_output_size, iou_threshold}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS);
@@ -121,11 +123,12 @@ be in the format of [class_id, score, left, top, right, bottom]
 or [score, left, top, right, bottom]. Set id_index to be -1 to
 ignore class_id axis.
 )doc" TVM_ADD_FILELINE)
-    .set_num_inputs(4)
+    .set_num_inputs(5)
     .add_argument("data", "Tensor", "Input data.")
     .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
     .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.")
     .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.")
+    .add_argument("iou_threshold", "Tensor", "Threshold for box overlap.")
     .set_support_level(5)
     .add_type_rel("NMS", NMSRel);
 
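
Taken together, the registrations above turn both thresholds into ordinary
graph inputs: get_valid_counts grows to two inputs and non_max_suppression to
five. A hedged end-to-end sketch at the Relay level, using the Python wrapper
signatures this commit updates (illustrative; see
python/tvm/relay/op/vision/nms.py):

    # Sketch: both thresholds supplied as run-time scalars. Dynamic inputs mean
    # the VM executor is required, matching the ONNX tests below.
    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 5, 6), dtype="float32")
    score_thresh = relay.var("score_threshold", shape=(), dtype="float32")
    iou_thresh = relay.var("iou_threshold", shape=(), dtype="float32")

    gv = relay.vision.get_valid_counts(data, score_threshold=score_thresh)
    nms = relay.vision.non_max_suppression(
        gv[1], gv[0], gv[2], max_output_size=-1, iou_threshold=iou_thresh
    )
    func = relay.Function([data, score_thresh, iou_thresh], nms.astuple())
    mod = tvm.IRModule.from_expr(func)

    ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
    res = ex.evaluate()(
        np.random.uniform(size=(1, 5, 6)).astype("float32"),
        np.array(0.0, dtype="float32"),
        np.array(0.5, dtype="float32"),
    )
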
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index d7a07f7..bae50c9 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -53,10 +53,9 @@ def get_tvm_output_with_vm(
     mod, params = relay.frontend.from_onnx(
         graph_def, shape_dict, opset=opset, freeze_params=freeze_params
     )
-    if convert_to_static:
-        from tvm.relay import transform
 
-        mod = transform.DynamicToStatic()(mod)
+    if convert_to_static:
+        mod = relay.transform.DynamicToStatic()(mod)
 
     ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
     result = ex.evaluate()(*input_data)
@@ -2821,7 +2820,6 @@ def test_unsqueeze_constant():
 
 
 def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"):
-    print(x_shape, kernel_shape, strides, mode, pads, auto_pad)
     x_np = np.random.uniform(size=x_shape).astype("float32")
 
     if mode == "max":
@@ -3690,6 +3688,99 @@ def test_roi_align():
     verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=2, spatial_scale=1.0)
 
 
+# @tvm.testing.uses_gpu
+def test_non_max_suppression():
+    def verify_nms(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims
+    ):
+        input_names = ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold"]
+        input_nodes = [
+            helper.make_tensor_value_info("boxes", TensorProto.FLOAT, boxes.shape),
+            helper.make_tensor_value_info("scores", TensorProto.FLOAT, scores.shape),
+            helper.make_tensor_value_info(
+                "max_output_boxes_per_class", TensorProto.INT64, max_output_boxes_per_class.shape
+            ),
+            helper.make_tensor_value_info("iou_threshold", TensorProto.FLOAT, iou_threshold.shape),
+        ]
+        inputs = [boxes, scores, max_output_boxes_per_class, iou_threshold]
+        if score_threshold is not None:
+            input_names.append("score_threshold")
+            input_nodes.append(
+                helper.make_tensor_value_info(
+                    "score_threshold", TensorProto.FLOAT, score_threshold.shape
+                )
+            )
+            inputs.append(score_threshold)
+        node = helper.make_node(
+            "NonMaxSuppression",
+            inputs=input_names,
+            outputs=["Y"],
+            center_point_box=0,
+        )
+
+        graph = helper.make_graph(
+            [node],
+            "nms_test",
+            inputs=input_nodes,
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, output_dims)],
+        )
+
+        model = helper.make_model(graph, producer_name="nms_test")
+
+        verify_with_ort_with_inputs(model, inputs, use_vm=True)
+
+    boxes = np.array(
+        [
+            [
+                [0.0, 0.0, 0.3, 0.3],
+                [0.0, 0.0, 0.4, 0.4],
+                [0.0, 0.0, 0.5, 0.5],
+                [0.5, 0.5, 0.9, 0.9],
+                [0.5, 0.5, 1.0, 1.0],
+            ],
+            [
+                [0.0, 0.0, 0.3, 0.3],
+                [0.0, 0.0, 0.4, 0.4],
+                [0.5, 0.5, 0.95, 0.95],
+                [0.5, 0.5, 0.96, 0.96],
+                [0.5, 0.5, 1.0, 1.0],
+            ],
+        ]
+    ).astype("float32")
+
+    scores = np.array(
+        [
+            [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]],
+            [[0.1, 0.2, 0.6, 0.3, 0.9], [0.1, 0.2, 0.6, 0.3, 0.9]],
+        ]
+    ).astype("float32")
+    max_output_boxes_per_class = np.array(2).astype("int64")
+    iou_threshold = np.array(0.8).astype("float32")
+    output_dims = [8, 3]
+    verify_nms(boxes, scores, max_output_boxes_per_class, iou_threshold, None, output_dims)
+
+    boxes = np.array(
+        [
+            [
+                [0.0, 0.0, 1.0, 1.0],
+                [0.0, 0.1, 1.0, 1.1],
+                [0.0, -0.1, 1.0, 0.9],
+                [0.0, 10.0, 1.0, 11.0],
+                [0.0, 10.1, 1.0, 11.1],
+                [0.0, 100.0, 1.0, 101.0],
+            ]
+        ]
+    ).astype(np.float32)
+    scores = np.array([[[0.9, 0.75, 0.6, 0.95, 0.5, 0.3]]]).astype(np.float32)
+    max_output_boxes_per_class = np.array([3]).astype(np.int64)
+    iou_threshold = np.array([0.5]).astype(np.float32)
+    score_threshold = np.array([0.4]).astype(np.float32)
+    output_dims = [2, 3]
+    verify_nms(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_dims
+    )
+
+
 def verify_cond_loop():
     y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [1])
     y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [1])
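
A note on the output_dims values used in the new test: ONNX NonMaxSuppression
produces selected_indices of shape [num_selected, 3], each row holding
(batch_index, class_index, box_index). In the first case, 2 batches x 2
classes x a cap of 2 boxes per class yields 8 rows; in the second, IoU
suppression alone would keep boxes 3 (0.95), 0 (0.9), and 5 (0.3), but the
0.4 score threshold drops box 5, leaving 2 rows. A quick arithmetic check:

    # Expected row counts for the two verify_nms cases above.
    num_batches, num_classes, max_boxes_per_class = 2, 2, 2
    assert num_batches * num_classes * max_boxes_per_class == 8  # first case: [8, 3]
    # Second case: boxes 3 and 0 survive both IoU and score filtering -> [2, 3].
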
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index b3b6553..1ce8a18 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -314,8 +314,8 @@ def test_get_valid_counts():
             intrp = relay.create_executor("debug", ctx=ctx, target=target)
             out = intrp.evaluate(func)(np_data)
             tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
-            # get_valid_count for cuda, opencl doesn't do data rearrangement
-            if target in ["cuda", "opencl"]:
+            # get_valid_count for opencl doesn't do data rearrangement
+            if target in ["opencl"]:
                 return
             tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
             tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04)
diff --git a/tests/python/relay/test_pass_lambda_lift.py b/tests/python/relay/test_pass_lambda_lift.py
index b19aebd..ce737b7 100644
--- a/tests/python/relay/test_pass_lambda_lift.py
+++ b/tests/python/relay/test_pass_lambda_lift.py
@@ -34,6 +34,7 @@ def test_basic():
     level1_func = relay.Function([x1, y1], level2_func(x1, y1))
 
     mod["main"] = level1_func
+    mod = relay.transform.InferType()(mod)
     new_mod = transform.LambdaLift()(mod)
     assert len(new_mod.functions) == 2
 
@@ -48,6 +49,7 @@ def test_closure():
     clo = outer_func(relay.ones(shape=(2,), dtype="float32"))
     mod["main"] = relay.Function([], relay.Call(clo, [relay.zeros(shape=(2,), dtype="float32")]))
 
+    mod = relay.transform.InferType()(mod)
     new_mod = transform.LambdaLift()(mod)
     assert len(new_mod.functions) == 3
 
@@ -75,6 +77,7 @@ def test_recursive():
     )
     mod["main"] = relay.Function([x], ret)
 
+    mod = relay.transform.InferType()(mod)
     new_mod = transform.LambdaLift()(mod)
     assert len(new_mod.functions) == 2
 
