This is an automated email from the ASF dual-hosted git repository.
zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 2845329 [TFLite] Implemented EXPAND_DIMS Operator for TFLite. (#6243)
2845329 is described below
commit 2845329009a42dcfbb3ac6d6dda7b578b8f8c585
Author: Rishabh Jain <[email protected]>
AuthorDate: Tue Aug 11 13:35:55 2020 +0530
[TFLite] Implemented EXPAND_DIMS Operator for TFLite. (#6243)
---
python/tvm/relay/frontend/tflite.py | 26 +++++++++++++
tests/python/frontend/tflite/test_forward.py | 56 ++++++++++++++++++++++++++++
2 files changed, 82 insertions(+)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index f168f1b..11d6576 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -84,6 +84,7 @@ class OperatorConverter(object):
'ELU': self.convert_elu,
'EQUAL': self.convert_equal,
'EXP': self.convert_exp,
+ 'EXPAND_DIMS': self.convert_expand_dims,
'FILL': self.convert_fill,
'FLOOR_DIV': self.convert_floor_div,
'FLOOR_MOD': self.convert_floor_mod,
@@ -2904,6 +2905,31 @@ class OperatorConverter(object):
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores,
valid_count]), size=4)
return ret
+ def convert_expand_dims(self, op):
+ """Convert TFLite EXPAND_DIMS"""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+
+ if input_tensors[0].qnn_params:
+ # Check that input and output tensor have same qnn params.
+ output_tensors = self.get_output_tensors(op)
+ assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \
+ "TFLite EXPAND_DIMS requires input and output tensors' \
+ scale and zero points to be equal"
+
+ input_expr = self.get_tensor_expr(input_tensors[0])
+ axis = self.get_tensor_value(input_tensors[1])
+ if isinstance(axis, np.ndarray):
+ assert len(axis) == 1, "only one value is expected."
+ axis = int(axis)
+
+ ndims = len(input_tensors[0].tensor.ShapeAsNumpy())
+ assert (-1-ndims <= axis <= ndims), "axis out of range"
+
+ out = _op.expand_dims(input_expr, axis, 1)
+
+ return out
+
def convert_one_hot(self, op):
"""Convert TFLite ONE_HOT"""
try:
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 2e57175..33ac6d4 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -2031,6 +2031,61 @@ def test_forward_padv2():
#######################################################################
+# EXPAND_DIMS
+# -----------
+
+def _test_expand_dims(input_shape, input_type, axis, quantized=False):
+ """ One iteration of EXPAND_DIMS """
+ with tf.Graph().as_default():
+ axis= ops.convert_to_tensor(axis, dtype=axis.dtype)
+
+ if quantized:
+ # ignoring input_type as quantized requires uint8
+ input = np.random.uniform(0, 256, input_shape).astype('uint8')
+ in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input")
+
+ input_range = {'q_input': (-100, 100)}
+ inq_input = tf.quantization.fake_quant_with_min_max_args(
+ in_input,
+ min=-100,
+ max=100,
+ name="q_input")
+
+ out = array_ops.expand_dims(inq_input, axis=axis)
+ out = tf.quantization.fake_quant_with_min_max_args(
+ out,
+ min=-100,
+ max=100,
+ name="out")
+
+ compare_tflite_with_tvm(
+ [input],
+ ["q_input"],
+ [inq_input],
+ [out],
+ quantized=True,
+ input_range=input_range)
+ else:
+ input = np.random.uniform(-100, 100, input_shape).astype(input_type)
+ in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
+
+ out = array_ops.expand_dims(in_input, axis=axis)
+
+ compare_tflite_with_tvm(
+ [input],
+ ["input"],
+ [in_input],
+ [out])
+
+def test_forward_expand_dims():
+ """ EXPAND_DIMS """
+ for quantized in [False, True]:
+ _test_expand_dims((6, 2, 7, 5), 'float32', np.int32(0), quantized=quantized)
+ _test_expand_dims((1, 2, 3), 'int32', np.int32(-2), quantized=quantized)
+ _test_expand_dims((2, 4, 5), 'float32', np.array([1], dtype=np.int32), quantized=quantized)
+
+
+#######################################################################
# ONE_HOT
# -------
@@ -3021,6 +3076,7 @@ if __name__ == '__main__':
test_forward_select()
test_forward_quantize_dequantize()
test_forward_arg_min_max()
+ test_forward_expand_dims()
# NN
test_forward_convolution()