yongwww commented on a change in pull request #4543: [FRONTEND][TFLITE] Add 
support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r363440816
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
 
         return out
 
+    def _convert_detection_postprocess(self, op):
+        """Convert TFLite_Detection_PostProcess"""
+        _option_names = [
+            "w_scale",
+            "max_detections",
+            "_output_quantized",
+            "detections_per_class",
+            "x_scale",
+            "nms_score_threshold",
+            "num_classes",
+            "max_classes_per_detection",
+            "use_regular_nms",
+            "y_scale",
+            "h_scale",
+            "_support_output_type_float_in_quantized_op",
+            "nms_iou_threshold"
+        ]
+
+        custom_options = get_custom_options(op, _option_names)
+        if custom_options["use_regular_nms"]:
+            raise tvm.error.OpAttributeUnImplemented(
+                "use_regular_nms=True is not yet supported for operator {}."
+                .format("TFLite_Detection_PostProcess")
+            )
+
+        inputs = self.get_input_tensors(op)
+        cls_pred = self.get_expr(inputs[1].tensor_idx)
+        loc_prob = self.get_expr(inputs[0].tensor_idx)
+        anchor_values = self.get_tensor_value(inputs[2])
+        anchor_boxes = len(anchor_values)
+        anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+        anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+        if inputs[0].qnn_params:
+            loc_prob = _qnn.op.dequantize(data=loc_prob,
+                                          
input_scale=inputs[0].qnn_params['scale'],
+                                          
input_zero_point=inputs[0].qnn_params['zero_point'])
+        if inputs[1].qnn_params:
+            cls_pred = _qnn.op.dequantize(data=cls_pred,
+                                          
input_scale=inputs[1].qnn_params['scale'],
+                                          
input_zero_point=inputs[1].qnn_params['zero_point'])
+        if inputs[2].qnn_params:
+            anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+                                             
input_scale=inputs[2].qnn_params['scale'],
+                                             
input_zero_point=inputs[2].qnn_params['zero_point'])
+
+        # reshape the cls_pred and loc_prob tensors so
+        # they can be consumed by multibox_transform_loc
+        cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+        # loc_prob coords are in yxhw format
+        # need to convert to xywh
+        loc_coords = _op.split(loc_prob, 4, axis=2)
+        loc_prob = _op.concatenate(
+            [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], 
axis=2
+        )
+        loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+        # anchor coords are in yxhw format
+        # need to convert to ltrb
+        anchor_coords = _op.split(anchor_expr, 4, axis=1)
+        anchor_y = anchor_coords[0]
+        anchor_x = anchor_coords[1]
+        anchor_h = anchor_coords[2]
+        anchor_w = anchor_coords[3]
+        plus_half = _expr.const(0.5, dtype='float32')
+        minus_half = _expr.const(-0.5, dtype='float32')
+        anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+        anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+        anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+        anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+        anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, 
anchor_b], axis=1)
+        anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+        # attributes for multibox_transform_loc
+        new_attrs0 = {}
+        new_attrs0["clip"] = False
+        new_attrs0["threshold"] = custom_options["nms_score_threshold"]
+        new_attrs0["variances"] = (
+            1/custom_options["x_scale"],
+            1/custom_options["y_scale"],
+            1/custom_options["w_scale"],
+            1/custom_options["h_scale"],
+        )
+
+        # attributes for non_max_suppression
+        new_attrs1 = {}
+        new_attrs1["return_indices"] = False
 
 Review comment:
   Thanks. Cool, then the Relay VM could be used to support dynamism.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to