mbarrett97 commented on a change in pull request #4543: [FRONTEND][TFLITE] Add
support for TFLite_Detection_PostProcess
URL: https://github.com/apache/incubator-tvm/pull/4543#discussion_r363254599
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1494,6 +1499,112 @@ def convert_transpose_conv(self, op):
return out
+ def _convert_detection_postprocess(self, op):
+ """Convert TFLite_Detection_PostProcess"""
+ _option_names = [
+ "w_scale",
+ "max_detections",
+ "_output_quantized",
+ "detections_per_class",
+ "x_scale",
+ "nms_score_threshold",
+ "num_classes",
+ "max_classes_per_detection",
+ "use_regular_nms",
+ "y_scale",
+ "h_scale",
+ "_support_output_type_float_in_quantized_op",
+ "nms_iou_threshold"
+ ]
+
+ custom_options = get_custom_options(op, _option_names)
+ if custom_options["use_regular_nms"]:
+ raise tvm.error.OpAttributeUnImplemented(
+ "use_regular_nms=True is not yet supported for operator {}."
+ .format("TFLite_Detection_PostProcess")
+ )
+
+ inputs = self.get_input_tensors(op)
+ cls_pred = self.get_expr(inputs[1].tensor_idx)
+ loc_prob = self.get_expr(inputs[0].tensor_idx)
+ anchor_values = self.get_tensor_value(inputs[2])
+ anchor_boxes = len(anchor_values)
+ anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
+ anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)
+
+ if inputs[0].qnn_params:
+ loc_prob = _qnn.op.dequantize(data=loc_prob,
+
input_scale=inputs[0].qnn_params['scale'],
+
input_zero_point=inputs[0].qnn_params['zero_point'])
+ if inputs[1].qnn_params:
+ cls_pred = _qnn.op.dequantize(data=cls_pred,
+
input_scale=inputs[1].qnn_params['scale'],
+
input_zero_point=inputs[1].qnn_params['zero_point'])
+ if inputs[2].qnn_params:
+ anchor_expr = _qnn.op.dequantize(data=anchor_expr,
+
input_scale=inputs[2].qnn_params['scale'],
+
input_zero_point=inputs[2].qnn_params['zero_point'])
+
+ # reshape the cls_pred and loc_prob tensors so
+ # they can be consumed by multibox_transform_loc
+ cls_pred = _op.transpose(cls_pred, [0, 2, 1])
+ # loc_prob coords are in yxhw format
+ # need to convert to xywh
+ loc_coords = _op.split(loc_prob, 4, axis=2)
+ loc_prob = _op.concatenate(
+ [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]],
axis=2
+ )
+ loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
+
+ # anchor coords are in yxhw format
+ # need to convert to ltrb
+ anchor_coords = _op.split(anchor_expr, 4, axis=1)
+ anchor_y = anchor_coords[0]
+ anchor_x = anchor_coords[1]
+ anchor_h = anchor_coords[2]
+ anchor_w = anchor_coords[3]
+ plus_half = _expr.const(0.5, dtype='float32')
+ minus_half = _expr.const(-0.5, dtype='float32')
+ anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
+ anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
+ anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
+ anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
+ anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r,
anchor_b], axis=1)
+ anchor_expr = _op.expand_dims(anchor_expr, 0)
+
+ # attributes for multibox_transform_loc
+ new_attrs0 = {}
+ new_attrs0["clip"] = False
+ new_attrs0["threshold"] = custom_options["nms_score_threshold"]
+ new_attrs0["variances"] = (
+ 1/custom_options["x_scale"],
+ 1/custom_options["y_scale"],
+ 1/custom_options["w_scale"],
+ 1/custom_options["h_scale"],
+ )
+
+ # attributes for non_max_suppression
+ new_attrs1 = {}
+ new_attrs1["return_indices"] = False
Review comment:
The output from TFLite always has a dynamic shape; however, since we're using
the graph runtime, the TVM output necessarily has a fixed shape. In practice
this means the TVM version will always output a tensor big enough to contain
the maximal number of detections, and only the first 'n' elements of that
tensor will be valid. The value of 'n' is also an output of the network (for
both TFLite and TVM).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services