This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 60cf692a63 [Frontend][TFLite] fix detection_postprocess's
non_max_suppression_attrs["force_suppress"] (#12593)
60cf692a63 is described below
commit 60cf692a63a22cd2698273c4945f037b4b22474b
Author: czh978 <[email protected]>
AuthorDate: Mon Sep 19 13:49:04 2022 +0800
[Frontend][TFLite] fix detection_postprocess's
non_max_suppression_attrs["force_suppress"] (#12593)
* [Frontend][TFLite] fix detection_postprocess's
non_max_suppression_attrs["force_suppress"]
TVM only supports the detection_postprocess operator with use_regular_nms
set to false. In that mode, TFLite's NMS suppresses boxes whose overlap
exceeds the IoU threshold regardless of their class. For the results of
TVM and TFLite to be consistent, force_suppress must therefore be set to
True.
* [Frontend][TFLite] fix detection_postprocess's
non_max_suppression_attrs["force_suppress"]
Added a test case that reproduces the inconsistent results between TVM and
TFLite. With force_suppress set to False the outputs diverge; setting
force_suppress to True makes the TVM results match TFLite's.
---
python/tvm/relay/frontend/tflite.py | 2 +-
tests/python/frontend/tflite/test_forward.py | 37 +++++++++++++++++++---------
2 files changed, 27 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py
b/python/tvm/relay/frontend/tflite.py
index 6c68230e0e..a7e10ad72e 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3355,7 +3355,7 @@ class OperatorConverter(object):
non_max_suppression_attrs = {}
non_max_suppression_attrs["return_indices"] = False
non_max_suppression_attrs["iou_threshold"] =
custom_options["nms_iou_threshold"]
- non_max_suppression_attrs["force_suppress"] = False
+ non_max_suppression_attrs["force_suppress"] = True
non_max_suppression_attrs["top_k"] = anchor_boxes
non_max_suppression_attrs["max_output_size"] =
custom_options["max_detections"]
non_max_suppression_attrs["invalid_to_bottom"] = False
diff --git a/tests/python/frontend/tflite/test_forward.py
b/tests/python/frontend/tflite/test_forward.py
index deaef72e1d..7b2bd60d8a 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -4311,13 +4311,8 @@ def test_forward_matrix_diag():
# ----------------
-def test_detection_postprocess():
- """Detection PostProcess"""
- tf_model_file = tf_testing.get_workload_official(
- "http://download.tensorflow.org/models/object_detection/"
- "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
- "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
- )
+def _test_detection_postprocess(tf_model_file, box_encodings_size,
class_predictions_size):
+ """One iteration of detection postProcess with given model and shapes"""
converter = tf.lite.TFLiteConverter.from_frozen_graph(
tf_model_file,
input_arrays=["raw_outputs/box_encodings",
"raw_outputs/class_predictions"],
@@ -4328,16 +4323,16 @@ def test_detection_postprocess():
"TFLite_Detection_PostProcess:3",
],
input_shapes={
- "raw_outputs/box_encodings": (1, 1917, 4),
- "raw_outputs/class_predictions": (1, 1917, 91),
+ "raw_outputs/box_encodings": box_encodings_size,
+ "raw_outputs/class_predictions": class_predictions_size,
},
)
converter.allow_custom_ops = True
converter.inference_type = tf.lite.constants.FLOAT
tflite_model = converter.convert()
np.random.seed(0)
- box_encodings = np.random.uniform(size=(1, 1917, 4)).astype("float32")
- class_predictions = np.random.uniform(size=(1, 1917, 91)).astype("float32")
+ box_encodings =
np.random.uniform(size=box_encodings_size).astype("float32")
+ class_predictions =
np.random.uniform(size=class_predictions_size).astype("float32")
tflite_output = run_tflite_graph(tflite_model, [box_encodings,
class_predictions])
tvm_output = run_tvm_graph(
tflite_model,
@@ -4382,6 +4377,26 @@ def test_detection_postprocess():
)
+def test_detection_postprocess():
+ """Detection PostProcess"""
+ box_encodings_size = (1, 1917, 4)
+ class_predictions_size = (1, 1917, 91)
+ tf_model_file = tf_testing.get_workload_official(
+ "http://download.tensorflow.org/models/object_detection/"
+ "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
+ "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
+ )
+ _test_detection_postprocess(tf_model_file, box_encodings_size,
class_predictions_size)
+
+ box_encodings_size = (1, 2034, 4)
+ class_predictions_size = (1, 2034, 91)
+ tf_model_file = download_testdata(
+
"https://github.com/czh978/models_for_tvm_test/raw/main/tflite_graph_with_postprocess.pb",
+ "tflite_graph_with_postprocess.pb",
+ )
+ _test_detection_postprocess(tf_model_file, box_encodings_size,
class_predictions_size)
+
+
#######################################################################
# Custom Converter
# ----------------