FrozenGene commented on a change in pull request #4440: [TFLite] Add
transpose_conv to TFLite parser
URL: https://github.com/apache/incubator-tvm/pull/4440#discussion_r351583281
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1370,6 +1371,85 @@ def convert_prelu(self, op):
return out
+ def convert_transpose_conv(self, op):
+ """Convert TFLite TRANSPOSE_CONV"""
+ try:
+ from tflite.BuiltinOptions import BuiltinOptions
+ from tflite.TensorType import TensorType
+ from tflite.Operator import Operator
+ from tflite.TransposeConvOptions import TransposeConvOptions
+ from tflite.Padding import Padding
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ assert isinstance(op, Operator)
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+
+ # Input (data) Tensor. NHWC layout
+ input_tensor = input_tensors[2]
+ _, _, _, input_c = input_tensor.tensor.ShapeAsNumpy()
+ # Weights tensor. TFLite uses OHWI layout
+ weights_tensor = input_tensors[1]
+ out_channels, kernel_h, kernel_w, in_channels =
weights_tensor.tensor.ShapeAsNumpy()
+ assert input_c == in_channels, \
+ "Input channel in the filter should match to channel in the input"
+ # output_shape Tensor. NHWC layout
+ output_shape_tensor = input_tensors[0]
+
+ output_tensors = self.get_output_tensors(op)
+ assert len(output_tensors) == 1, "output tensors length should be 1"
+ output_tensor = output_tensors[0]
+ output_tensor_type = output_tensor.tensor.Type()
+ output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)
+
+ assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
+ op_options = op.BuiltinOptions()
+ deconv_options = TransposeConvOptions()
+ deconv_options.Init(op_options.Bytes, op_options.Pos)
+
+ padding = deconv_options.Padding()
+ stride_h = deconv_options.StrideH()
+ stride_w = deconv_options.StrideW()
+ assert padding in (Padding.VALID, Padding.SAME), \
+ 'Padding format {} is not supported for operator
TRANSPOSE_CONV'.format(padding)
+
+ # Data
+ in_expr = self.get_expr(input_tensor.tensor_idx)
+
+ # Weights
+ weights_tensor_type = weights_tensor.tensor.Type()
+ # weights tensor type should be UINT8 (quantization) or FLOAT32
+ assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+ weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
+ weight_value_ohwi = self.get_tensor_value(weights_tensor)
+ # Relay kernel_layout should be OIHW
+ # Relay weights layout should be different from kernel_layout - it should be IOHW
+ weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
+ weight_expr_iohw = self.exp_tab.new_const(weight_value_iohw,
dtype=weight_tensor_type_str)
+
+ # Output shape value
+ output_shape_value = self.get_tensor_value(output_shape_tensor)
+ # Relay expects the filter's output channel to match the output tensor's channel.
+ assert out_channels == output_shape_value[3], \
+ "Output channel in the filter should match to channel in the
output_shape"
+
+ # TF frontend supports 'SAME' padding for kernel 1x1 only. Let's do the same here
+ if padding == Padding.SAME:
+ assert (kernel_h, kernel_w) == (1, 1), \
+ "SAME padding is supported for kernel (1,1) only"
Review comment:
If the kernel size (kh, kw) is 3x3, what error message does the user currently get?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services