This is an automated email from the ASF dual-hosted git repository.

anijain2305 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ee860e  [TFLite] Cast operator adapted for MLIR-based convertor (#7639)
2ee860e is described below

commit 2ee860e902e77f45996a5585fc09c5e5c29788e1
Author: Dmitriy Smirnov <[email protected]>
AuthorDate: Fri Mar 19 06:47:45 2021 +0000

    [TFLite] Cast operator adapted for MLIR-based convertor (#7639)
    
    * [TFLite] Cast operator adapted for MLIR-based convertor
    
    Cast operator now can be executed in MLIR-based version.
    Unit test updated
    
    Change-Id: I30e5c1c9d69355116b560af8f6d0582b2d593538
    
    * Comment added
    
    Change-Id: I3e2d29ef201283de337168d0b82679b63ca2fcf4
---
 python/tvm/relay/frontend/tflite.py          | 17 ++++++++++++-----
 tests/python/frontend/tflite/test_forward.py | 19 ++++++++++++++-----
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index d6f7047..a5c9a58 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -2336,11 +2336,18 @@ class OperatorConverter(object):
         input_tensor = input_tensors[0]
         in_expr = self.get_expr(input_tensor.tensor_idx)
 
-        assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
-        op_options = op.BuiltinOptions()
-        cast_options = CastOptions()
-        cast_options.Init(op_options.Bytes, op_options.Pos)
-        cast_dtype = cast_options.OutDataType()
+        # MLIR-based converter outputs no BuiltinOptions for Cast operator. In this
+        # case the output type can be derived from the Cast operator output tensor.
+        # When TOCO converter is used there will be "normal" BuiltinOptions.CastOptions
+        # with output type.
+        if op.BuiltinOptions() is not None:
+            assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
+            op_options = op.BuiltinOptions()
+            cast_options = CastOptions()
+            cast_options.Init(op_options.Bytes, op_options.Pos)
+            cast_dtype = cast_options.OutDataType()
+        else:
+            cast_dtype = self.get_output_tensors(op)[0].tensor.Type()
 
         out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype))
 
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 0d02c15..7c12cd3 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -647,19 +647,28 @@ def test_forward_transpose():
 # ----
 
 
-def _test_cast(data, cast_dtype):
+def _test_cast(data, cast_dtype, use_mlir=False):
     """ One iteration of CAST """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = math_ops.cast(in_data, cast_dtype)
-        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+        compare_tflite_with_tvm(
+            data, "Placeholder:0", [in_data], [out], experimental_new_converter=use_mlir
+        )
 
 
 def test_forward_cast():
     """ CAST """
-    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32)
-    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
-    _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
+    for use_mlir in [False, True]:
+        _test_cast(
+            np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32, use_mlir=use_mlir
+        )
+        _test_cast(
+            np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8, use_mlir=use_mlir
+        )
+        _test_cast(
+            np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64, use_mlir=use_mlir
+        )
 
 
 #######################################################################

Reply via email to