FrozenGene commented on a change in pull request #4440: [TFLite] Add transpose_conv to TFLite parser
URL: https://github.com/apache/incubator-tvm/pull/4440#discussion_r351587309
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1370,6 +1371,85 @@ def convert_prelu(self, op):
 
         return out
 
+    def convert_transpose_conv(self, op):
+        """Convert TFLite TRANSPOSE_CONV"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.TensorType import TensorType
+            from tflite.Operator import Operator
+            from tflite.TransposeConvOptions import TransposeConvOptions
+            from tflite.Padding import Padding
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        # Input (data) Tensor. NHWC layout
+        input_tensor = input_tensors[2]
+        _, _, _, input_c = input_tensor.tensor.ShapeAsNumpy()
+        # Weights tensor. TFLite uses OHWI layout
+        weights_tensor = input_tensors[1]
+        out_channels, kernel_h, kernel_w, in_channels = 
weights_tensor.tensor.ShapeAsNumpy()
+        assert input_c == in_channels, \
+            "Input channel in the filter should match to channel in the input"
+        # output_shape Tensor. NHWC layout
+        output_shape_tensor = input_tensors[0]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_tensor_type = output_tensor.tensor.Type()
+        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
+        op_options = op.BuiltinOptions()
+        deconv_options = TransposeConvOptions()
+        deconv_options.Init(op_options.Bytes, op_options.Pos)
+
+        padding = deconv_options.Padding()
+        stride_h = deconv_options.StrideH()
+        stride_w = deconv_options.StrideW()
+        assert padding in (Padding.VALID, Padding.SAME), \
+            'Padding format {} is not supported for operator 
TRANSPOSE_CONV'.format(padding)
+
+        # Data
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        # Weights
+        weights_tensor_type = weights_tensor.tensor.Type()
+        # weights tensor type should be UINT8 (quantization) or FLOAT32
+        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
+        weight_value_ohwi = self.get_tensor_value(weights_tensor)
+        # Relay kernel_layout should be OIHW
+        # Relay weights layout should be different from kernel_layout - it 
should be IOHW
+        weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
+        weight_expr_iohw = self.exp_tab.new_const(weight_value_iohw, 
dtype=weight_tensor_type_str)
+
+        # Output shape value
+        output_shape_value = self.get_tensor_value(output_shape_tensor)
+        # Relay expects filter output channel to match to output tensor 
channel.
+        assert out_channels == output_shape_value[3], \
+            "Output channel in the filter should match to channel in the 
output_shape"
+
+        # TF frontend supports 'SAME' padding for kernel 1x1 only. Lets do the 
same here
+        if padding == Padding.SAME:
+            assert (kernel_h, kernel_w) == (1, 1), \
+                "SAME padding is supported for kernel (1,1) only"
 
 Review comment:
  Could we support non-1x1 conv_transpose too? I think this may be a good time to implement it fully, for both the TF and TFLite frontends.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to