This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 83b310d5a4 [Frontend][TFLite] Add support for NonMaxSuppressionV5 op (#12003)
83b310d5a4 is described below
commit 83b310d5a41b92a857c17d25a0a9b0546441586a
Author: Black <[email protected]>
AuthorDate: Tue Jul 5 14:17:16 2022 +0800
[Frontend][TFLite] Add support for NonMaxSuppressionV5 op (#12003)
* add nms_v5 op for TFLite
* add a test for the TFLite nms_v5 op
---
python/tvm/relay/frontend/tflite.py | 64 ++++++++++++++++++++++++++++
tests/python/frontend/tflite/test_forward.py | 40 +++++++++++++++++
2 files changed, 104 insertions(+)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 2a9d66acff..d7ec441e0e 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -176,6 +176,7 @@ class OperatorConverter(object):
"UNIDIRECTIONAL_SEQUENCE_LSTM":
self.convert_unidirectional_sequence_lstm,
"WHERE": self.convert_select,
"ZEROS_LIKE": self.convert_zeros_like,
+ "NON_MAX_SUPPRESSION_V5": self.convert_nms_v5,
}
def check_unsupported_ops(self):
@@ -3347,6 +3348,69 @@ class OperatorConverter(object):
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores,
valid_count]), size=4)
return ret
+ def convert_nms_v5(self, op):
+ """Convert TFLite NonMaxSuppressionV5"""
+ #
https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
+ # Inputs, in order: boxes, scores, max_output_size, iou_threshold, score_threshold, soft_nms_sigma.
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 6, "input tensor length should be 6"
+ boxes = self.get_expr(input_tensors[0].tensor_idx)  # boxes and scores are fetched with get_expr
+ scores = self.get_expr(input_tensors[1].tensor_idx)
+ max_output_size = self.get_tensor_value(input_tensors[2])  # the remaining four are read with get_tensor_value
+ iou_threshold = self.get_tensor_value(input_tensors[3])
+ score_threshold = self.get_tensor_value(input_tensors[4])
+ soft_nms_sigma = self.get_tensor_value(input_tensors[5])
+ # Each value that arrives as a numpy array must hold exactly one element; it is collapsed to a Python scalar.
+ if isinstance(max_output_size, np.ndarray):
+ assert max_output_size.size == 1, "only one value is expected."
+ max_output_size = int(max_output_size)
+ # iou_threshold: same one-element check, converted with float().
+ if isinstance(iou_threshold, np.ndarray):
+ assert iou_threshold.size == 1, "only one value is expected."
+ iou_threshold = float(iou_threshold)
+ # score_threshold: same one-element check, converted with float().
+ if isinstance(score_threshold, np.ndarray):
+ assert score_threshold.size == 1, "only one value is expected."
+ score_threshold = float(score_threshold)
+ # soft_nms_sigma: same conversion, then any non-zero value is rejected with OpNotImplemented.
+ if isinstance(soft_nms_sigma, np.ndarray):
+ assert soft_nms_sigma.size == 1, "only one value is expected."
+ soft_nms_sigma = float(soft_nms_sigma)
+ if soft_nms_sigma != 0.0:  # NOTE(review): indentation is lost in this copy; confirm whether this check is nested in the isinstance branch
+ raise tvm.error.OpNotImplemented(
+ "It is soft_nms when soft_nms_sigma != 0, which is not
supported!"
+ )
+ # Pack score (index 0 on the last axis) followed by the box values into one tensor, then add a new axis 0.
+ scores_expand = _op.expand_dims(scores, axis=-1, num_newaxis=1)
+ data = _op.concatenate([scores_expand, boxes], -1)
+ data = _op.expand_dims(data, axis=0, num_newaxis=1)  # presumably the batch axis the vision ops expect - TODO confirm
+ # Apply score_threshold first; id_index=-1 presumably means "no class-id column" - TODO confirm.
+ count, data, indices = _op.vision.get_valid_counts(
+ data, score_threshold=score_threshold, id_index=-1, score_index=0
+ )
+ # score_index=0 / coord_start=1 match the packing above; nms_ret[0] and nms_ret[1] are consumed below.
+ nms_ret = _op.vision.non_max_suppression(
+ data=data,
+ valid_count=count,
+ indices=indices,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ force_suppress=True,
+ top_k=-1,
+ coord_start=1,
+ score_index=0,
+ id_index=-1,
+ return_indices=True,
+ invalid_to_bottom=False,
+ )
+ # Drop axis 0 again, keep the first max_output_size indices, and gather the matching scores.
+ selected_indices = _op.squeeze(nms_ret[0], axis=[0])
+ selected_indices = _op.strided_slice(selected_indices, [0],
[max_output_size])
+ valide_num = _op.squeeze(nms_ret[1], axis=[1])  # NOTE(review): "valide_num" looks like a typo for "valid_num"
+ selected_scores = _op.take(scores, selected_indices, axis=0)
+ out = _expr.TupleWrapper(_expr.Tuple([selected_indices,
selected_scores, valide_num]), 3)
+ return out  # three outputs: selected indices, selected scores, squeezed nms_ret[1]
+
def convert_expand_dims(self, op):
"""Convert TFLite EXPAND_DIMS"""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 23b5a03ffb..c271a669e9 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -44,6 +44,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import image_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
@@ -4937,6 +4938,42 @@ def test_prevent_tensorflow_dynamic_range():
tvm_output = run_tvm_graph(tflite_model, data_array,
data_in.name.replace(":0", ""))
+def _test_nms_v5(
+ bx_shape, score_shape, iou_threshold, score_threshold, max_output_size,
dtype="float32"
+):
+ """One iteration of nms_v5 with given attributes"""
+ boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)  # random box values drawn from [0, 10)
+ scores = np.random.uniform(size=score_shape).astype(dtype)  # random scores drawn from [0, 1)
+ # Reset the default graph, disable eager execution, and feed two placeholders into non_max_suppression_with_scores.
+ tf.reset_default_graph()
+ tf.compat.v1.disable_eager_execution()
+ in_data_1 = array_ops.placeholder(dtype, boxes.shape, name="in_data_1")
+ in_data_2 = array_ops.placeholder(dtype, scores.shape, name="in_data_2")
+ out = image_ops.non_max_suppression_with_scores(
+ boxes=in_data_1,
+ scores=in_data_2,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ name="nms",
+ )
+ # Only out[0] and out[1] are handed to compare_tflite_with_tvm.
+ compare_tflite_with_tvm(
+ [boxes, scores],
+ ["in_data_1:0", "in_data_2:0"],
+ [in_data_1, in_data_2],
+ [out[0], out[1]],
+ out_names=[out[0].name, out[1].name],
+ experimental_new_converter=True,
+ )
+
+
+def test_forward_nms_v5():
+ """test nms_v5"""
+ _test_nms_v5((10000, 4), (10000,), 0.5, 0.4, 100)  # bx_shape, score_shape, iou_threshold=0.5, score_threshold=0.4, max_output_size=100
+ _test_nms_v5((1000, 4), (1000,), 0.7, 0.3, 50)  # smaller input: iou_threshold=0.7, score_threshold=0.3, max_output_size=50
+
+
#######################################################################
# Main
# ----
@@ -5031,6 +5068,9 @@ if __name__ == "__main__":
# Detection_PostProcess
test_detection_postprocess()
+ # NonMaxSuppressionV5
+ test_forward_nms_v5()
+
# Overwrite Converter
test_custom_op_converter()