apivovarov commented on a change in pull request #4447:
[Relay][Frontend][TFlite] Add parser support for UNPACK tflite operator
URL: https://github.com/apache/incubator-tvm/pull/4447#discussion_r354039512
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1238,6 +1239,54 @@ def convert_pack(self, op):
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out
def convert_unpack(self, op):
    """Convert a TFLite UNPACK operator into Relay.

    UNPACK splits the input into ``Num()`` pieces along ``Axis()`` and
    removes that axis from every piece. Relay has no 'unpack' operator,
    so this is lowered to ``split`` followed by ``squeeze`` on the split
    axis.

    Parameters
    ----------
    op : tflite.Operator.Operator
        The UNPACK operator to convert; it must carry ``UnpackOptions``
        and have exactly one input tensor.

    Returns
    -------
    Expr or TupleWrapper
        The squeezed input itself when ``Num() == 1``; otherwise a
        ``TupleWrapper`` with one squeezed expression per output.

    Raises
    ------
    ImportError
        If the ``tflite`` package is not installed.
    """
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.UnpackOptions import UnpackOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    in_expr = self.get_expr(input_tensor.tensor_idx)
    assert op.BuiltinOptionsType() == BuiltinOptions.UnpackOptions
    op_options = op.BuiltinOptions()
    unpack_options = UnpackOptions()
    unpack_options.Init(op_options.Bytes, op_options.Pos)
    num_unpacks = unpack_options.Num()
    unpack_axis = unpack_options.Axis()

    # Squeeze only the unpacked axis. Relay's squeeze accepts an axis list,
    # and [0] is a valid one. Using axis=None here would instead drop *every*
    # size-1 dimension, which is wrong whenever the input has other unit
    # dimensions (unpacking a (2, 1, 3) tensor along axis 0 must give two
    # (1, 3) tensors, not two (3,) tensors).
    squeeze_axis = [unpack_axis]

    # Relay doesn't like a TupleWrapper of 1 element (for reference see
    # convert_split()). Unpacking into a single piece leaves the tensor
    # unchanged apart from the unit axis, so skip 'split' and just squeeze.
    if num_unpacks == 1:
        return _op.squeeze(in_expr, axis=squeeze_axis)

    splitted = _op.split(in_expr,
                         indices_or_sections=num_unpacks,
                         axis=unpack_axis)
    return _expr.TupleWrapper(
        _expr.Tuple([_op.squeeze(split_item, axis=squeeze_axis)
                     for split_item in splitted]),
        len(splitted))
Review comment:
just "return squeezed"?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services