This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 60cf692a63 [Frontend][TFLite] fix detection_postprocess's 
non_max_suppression_attrs["force_suppress"] (#12593)
60cf692a63 is described below

commit 60cf692a63a22cd2698273c4945f037b4b22474b
Author: czh978 <[email protected]>
AuthorDate: Mon Sep 19 13:49:04 2022 +0800

    [Frontend][TFLite] fix detection_postprocess's 
non_max_suppression_attrs["force_suppress"] (#12593)
    
    * [Frontend][TFLite]fix detection_postprocess's 
non_max_suppression_attrs["force_suppress"]
    
    Since TVM only supports the detection_postprocess operator when
    use_regular_nms is false, and TFLite's NMS implementation suppresses
    boxes that exceed the threshold regardless of their class, we need to
    set force_suppress to True for the TVM and TFLite results to be
    consistent.
    
    * [Frontend][TFLite]fix detection_postprocess's 
non_max_suppression_attrs[force_suppress]
    
    Added a test case that reproduces the inconsistent results between TVM
    and TFLite when force_suppress is false; setting force_suppress to true
    yields the correct result.
---
 python/tvm/relay/frontend/tflite.py          |  2 +-
 tests/python/frontend/tflite/test_forward.py | 37 +++++++++++++++++++---------
 2 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py 
b/python/tvm/relay/frontend/tflite.py
index 6c68230e0e..a7e10ad72e 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3355,7 +3355,7 @@ class OperatorConverter(object):
         non_max_suppression_attrs = {}
         non_max_suppression_attrs["return_indices"] = False
         non_max_suppression_attrs["iou_threshold"] = 
custom_options["nms_iou_threshold"]
-        non_max_suppression_attrs["force_suppress"] = False
+        non_max_suppression_attrs["force_suppress"] = True
         non_max_suppression_attrs["top_k"] = anchor_boxes
         non_max_suppression_attrs["max_output_size"] = 
custom_options["max_detections"]
         non_max_suppression_attrs["invalid_to_bottom"] = False
diff --git a/tests/python/frontend/tflite/test_forward.py 
b/tests/python/frontend/tflite/test_forward.py
index deaef72e1d..7b2bd60d8a 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -4311,13 +4311,8 @@ def test_forward_matrix_diag():
 # ----------------
 
 
-def test_detection_postprocess():
-    """Detection PostProcess"""
-    tf_model_file = tf_testing.get_workload_official(
-        "http://download.tensorflow.org/models/object_detection/"
-        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
-        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
-    )
+def _test_detection_postprocess(tf_model_file, box_encodings_size, 
class_predictions_size):
+    """One iteration of detection postProcess with given model and shapes"""
     converter = tf.lite.TFLiteConverter.from_frozen_graph(
         tf_model_file,
         input_arrays=["raw_outputs/box_encodings", 
"raw_outputs/class_predictions"],
@@ -4328,16 +4323,16 @@ def test_detection_postprocess():
             "TFLite_Detection_PostProcess:3",
         ],
         input_shapes={
-            "raw_outputs/box_encodings": (1, 1917, 4),
-            "raw_outputs/class_predictions": (1, 1917, 91),
+            "raw_outputs/box_encodings": box_encodings_size,
+            "raw_outputs/class_predictions": class_predictions_size,
         },
     )
     converter.allow_custom_ops = True
     converter.inference_type = tf.lite.constants.FLOAT
     tflite_model = converter.convert()
     np.random.seed(0)
-    box_encodings = np.random.uniform(size=(1, 1917, 4)).astype("float32")
-    class_predictions = np.random.uniform(size=(1, 1917, 91)).astype("float32")
+    box_encodings = 
np.random.uniform(size=box_encodings_size).astype("float32")
+    class_predictions = 
np.random.uniform(size=class_predictions_size).astype("float32")
     tflite_output = run_tflite_graph(tflite_model, [box_encodings, 
class_predictions])
     tvm_output = run_tvm_graph(
         tflite_model,
@@ -4382,6 +4377,26 @@ def test_detection_postprocess():
         )
 
 
+def test_detection_postprocess():
+    """Detection PostProcess"""
+    box_encodings_size = (1, 1917, 4)
+    class_predictions_size = (1, 1917, 91)
+    tf_model_file = tf_testing.get_workload_official(
+        "http://download.tensorflow.org/models/object_detection/"
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
+        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb",
+    )
+    _test_detection_postprocess(tf_model_file, box_encodings_size, 
class_predictions_size)
+
+    box_encodings_size = (1, 2034, 4)
+    class_predictions_size = (1, 2034, 91)
+    tf_model_file = download_testdata(
+        "https://github.com/czh978/models_for_tvm_test/raw/main/tflite_graph_with_postprocess.pb",
+        "tflite_graph_with_postprocess.pb",
+    )
+    _test_detection_postprocess(tf_model_file, box_encodings_size, 
class_predictions_size)
+
+
 #######################################################################
 # Custom Converter
 # ----------------

Reply via email to