inadob commented on a change in pull request #4704: [Relay][Frontend][TFLite] Add parser support for arg_min_max
URL: https://github.com/apache/incubator-tvm/pull/4704#discussion_r374611987
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -826,6 +828,50 @@ def _convert_reduce_prod(self, op):
     def _convert_reduce_sum(self, op):
         return self._convert_reduce(_op.reduce.sum, op)
 
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method to convert TFLite arg_min_max"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        axis_tensor = input_tensors[1]
+        # we support the case when the axis is a scalar not a tensor
+        axis_value = int(self.get_tensor_value(axis_tensor))
+
+        if op.BuiltinOptionsType() == BuiltinOptions.ArgMinOptions:
+            arg_min_max_options = ArgMinOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.ArgMaxOptions:
+            arg_min_max_options = ArgMaxOptions()
+        op_options = op.BuiltinOptions()
+        arg_min_max_options.Init(op_options.Bytes, op_options.Pos)
+        output_dtype = arg_min_max_options.OutputType()
+
+        # set keepdims to True since tflite 1.13 removes all dims of size 1
+        # WARNING: all other versions of tflite > 1.13 need keepdims=False
 
 Review comment:
  The issue here is that we can't use the same TFLite version check that we use
in the tests. As far as I know, there is no way to determine which TFLite version
was used to convert a model that is passed to the parser.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to