yongwww commented on a change in pull request #4312:
URL: https://github.com/apache/incubator-tvm/pull/4312#discussion_r432886691



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -614,6 +614,52 @@ def _impl(inputs, attr, params, mod):
         return out
     return _impl
 
+def _nms():
+    def _impl(inputs, attr, params, mod):
+        # Get parameter values
+        max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy().astype("int64"))[0])
+        iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
+        # score_threshold was introduced from V3
+        score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0
+
+        # Generate data with shape (1, num_anchors, 5)
+        scores = AttrCvt(op_name="expand_dims",
+                         ignores=['T_threshold'],
+                         extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
+        data = get_relay_op('concatenate')([scores, inputs[0]], -1)
+        data = get_relay_op('expand_dims')(data, 0, 1)
+
+        # reason why using get_valid_counts is for inference performance
+        ct, data, indices = get_relay_op('get_valid_counts')(data,
+                                                             score_threshold=score_threshold,
+                                                             id_index=-1,
+                                                             score_index=0)
+        # TensorFlow NMS doesn't have parameter top_k
+        top_k = -1
+        # TF doesn't have class id for nms input
+        score_index = 0
+        nms_ret = get_relay_op('non_max_suppression')(data=data,
+                                                      valid_count=ct,
+                                                      indices=indices,
+                                                      max_output_size=max_output_size,
+                                                      iou_threshold=iou_threshold,
+                                                      force_suppress=True,
+                                                      top_k=top_k,
+                                                      coord_start=1,
+                                                      score_index=score_index,
+                                                      id_index=-1,
+                                                      return_indices=True,
+                                                      invalid_to_bottom=False)
+
+        # squeeze it, TF NMS is not batched
+        end = get_relay_op("squeeze")(nms_ret[1], axis=[1])
+        data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
+
+        # slice to get the dynamic result
+        ret = get_relay_op("strided_slice")(data_slice, _expr.const([0]), end, _expr.const([1]))
+        return ret
+    return _impl

Review comment:
       Updated. I also used `slice_mode` for the TF `Slice` op.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to