u99127 commented on a change in pull request #5848:
URL: https://github.com/apache/incubator-tvm/pull/5848#discussion_r448548518
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -651,12 +701,43 @@ def convert_shape(self, op):
def convert_relu(self, op):
"""Convert TFLite ReLU"""
+ try:
+ from tflite.ActivationFunctionType import ActivationFunctionType
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
Review comment:
I think this is unnecessary given the import of ActivationFunctionType
in the constructor
[here](https://github.com/apache/incubator-tvm/blob/b979bf6a7630ada055fff2d65e1cd0f8d55bb6a0/python/tvm/relay/frontend/tflite.py#L54)
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -692,6 +773,11 @@ def _hard_swish(data):
def convert_relu6(self, op):
"""Convert TFLite ReLU6"""
+ try:
+ from tflite.ActivationFunctionType import ActivationFunctionType
Review comment:
Same as for relu: I think this is unnecessary given the import of
ActivationFunctionType in the constructor
[here](https://github.com/apache/incubator-tvm/blob/b979bf6a7630ada055fff2d65e1cd0f8d55bb6a0/python/tvm/relay/frontend/tflite.py#L54)
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -263,21 +305,29 @@ def get_tensor_value(self, tensor_wrapper):
except ImportError:
raise ImportError("The tflite package must be installed")
+ # Read the data from the buffer. Also extract the shape.
+ # The shape is used later to reshape the data.
+ data = tensor_wrapper.buffer.DataAsNumpy()
+ shape = tensor_wrapper.tensor.ShapeAsNumpy()
+
+ # When TFLite buffer is of size 1 (scalar), then TFLite tensor shape
is set to 0.
+ # Therefore, we set the shape to 1 for numpy reshape to work. Set
shape to 1 if the data is
+ # a scalar type
+ if data.size == 1 and isinstance(shape, int) and shape == 0:
+ shape = (1,)
+
+ if tensor_wrapper.tensor.Type() == TensorType.INT8:
Review comment:
Minor nit, and the credit for this really belongs to Dmitriy Smirnov
(https://github.com/d-smirnov): the condition here could be pulled out into a
helper function that uses a dictionary to map from TensorType to numpy type.
That would make the code much cleaner and reduce duplication.
E.g. something like:
def get_tensor_type_as_numpy(self, tensor_wrapper):
    """Return the numpy dtype matching the TFLite TensorType of a tensor.

    Parameters
    ----------
    tensor_wrapper : TensorWrapper
        Wrapper whose ``tensor.Type()`` is mapped to a numpy type.

    Returns
    -------
    type
        One of ``np.uint8``, ``np.int8``, ``np.float32``, ``np.int32``,
        ``np.int64`` or ``np.bool_``.

    Raises
    ------
    ImportError
        If the tflite package is not installed.
    NotImplementedError
        If the tensor's type has no numpy mapping here.
    """
    assert isinstance(tensor_wrapper, TensorWrapper)

    # Keep the try narrow: only the tflite import can raise ImportError.
    try:
        from tflite.TensorType import TensorType
    except ImportError:
        raise ImportError("The tflite package must be installed")

    tensor_type = tensor_wrapper.tensor.Type()
    # A lookup table replaces the per-type if/elif chain and keeps every
    # supported mapping in one place.
    type_map = {
        TensorType.UINT8: np.uint8,
        TensorType.INT8: np.int8,
        TensorType.FLOAT32: np.float32,
        TensorType.INT32: np.int32,
        TensorType.INT64: np.int64,
        TensorType.BOOL: np.bool_,
    }
    try:
        return type_map[tensor_type]
    except KeyError:
        raise NotImplementedError(
            "Tensor type '{}' currently not supported".format(tensor_type)
        )
def get_tensor_value(self, tensor_wrapper):
    """Return the tensor's buffer contents as a numpy array of its shape.

    The raw buffer bytes are reinterpreted with the dtype chosen by
    ``get_tensor_type_as_numpy``. A tensor whose shape has length 0 is
    reshaped with an empty shape, i.e. to a 0-d array.
    """
    assert isinstance(tensor_wrapper, TensorWrapper)
    dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
    raw_bytes = tensor_wrapper.buffer.DataAsNumpy()
    flat_values = np.frombuffer(raw_bytes, dtype=dtype)
    # An empty shape vector means a scalar; reshape([]) yields a 0-d array.
    if tensor_wrapper.tensor.ShapeLength() != 0:
        target_shape = tensor_wrapper.tensor.ShapeAsNumpy()
    else:
        target_shape = []
    return flat_values.reshape(target_shape)
##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -787,14 +897,30 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1],
'SAME', 'NHWC', quantized=quantized)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2],
'VALID', 'NHWC', quantized=quantized)
- # depthwise convolution
- _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME',
'NHWC', True)
- _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID',
'NHWC', True)
- _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1],
'SAME', 'NHWC', True)
- _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID',
'NHWC', True)
- _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID',
'NHWC', True)
- # dephtwise convolution with single input channel
- _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME',
'NHWC', True)
+ # depthwise convolution
+ _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1],
'SAME', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2],
'VALID', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1],
'SAME', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2],
'VALID', 'NHWC', True, quantized=quantized)
+ _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2],
'VALID', 'NHWC', True, quantized=quantized)
+ # dephtwise convolution with single input channel
Review comment:
Typo in the comment: "dephtwise" should be "depthwise".
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]