This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 70399da0a2 [TFLite] Support for BATCH_MATMUL tflite operator (#14423)
70399da0a2 is described below

commit 70399da0a235dfe999bd676b6e6ffc6af8c33e5f
Author: neildhickey <[email protected]>
AuthorDate: Thu Mar 30 20:56:36 2023 +0100

    [TFLite] Support for BATCH_MATMUL tflite operator (#14423)
    
    * [TFLite] Support for BATCH_MATMUL tflite operator
    
    Adds support for BATCH_MATMUL operator in the TFLite frontend.
    
    Adds a test that checks supported TFLite types.
    
    * Fixing linting issues
    
    * Fixing more lint issues
    
    * Fixing compare_tflite_with_tvm for graphs with fewer than two input tensors
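
    For context, a minimal sketch of driving the new conversion path from
    Python; the model file name and tensor names below are placeholders,
    not part of this patch:

        import tflite
        from tvm import relay

        # "model.tflite" stands in for any model containing BATCH_MATMUL
        with open("model.tflite", "rb") as f:
            tflite_model = tflite.Model.GetRootAsModel(f.read(), 0)

        # from_tflite now dispatches BATCH_MATMUL to convert_batch_matmul
        mod, params = relay.frontend.from_tflite(
            tflite_model,
            shape_dict={"A": (3, 5, 4), "B": (3, 4, 5)},
            dtype_dict={"A": "float32", "B": "float32"},
        )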
---
 python/tvm/relay/frontend/tflite.py          | 147 +++++++++++++++++++++++++++
 tests/python/frontend/tflite/test_forward.py |  74 ++++++++++++--
 2 files changed, 212 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index db21fa6668..9daf7f716f 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -32,7 +32,9 @@ from .. import function as _function
 from .. import op as _op
 from .. import qnn as _qnn
 from .common import ExprTable
+from .common import fold_constant as _fold_constant
 from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
 from .common import lstm_cell, to_int_list, shape_of, try_infer_value
 from .common import set_span
 from .tflite_flexbuffer import FlexBufferDecoder
@@ -80,6 +82,7 @@ class OperatorConverter(object):
             "ARG_MIN": self.convert_arg_min,
             "AVERAGE_POOL_2D": self.convert_average_pool2d,
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
+            "BATCH_MATMUL": self.convert_batch_matmul,
             "CAST": self.convert_cast,
             "CEIL": self.convert_ceil,
             "CONCATENATION": self.convert_concatenation,
@@ -492,6 +495,21 @@ class OperatorConverter(object):
             "Tensor type {} is currently not 
supported".format(str(tensor_type))
         )
 
+    def flatten_to_nd(self, x, x_shape, nd=3):
+        """Flatten input tensor to nd rank"""
+        ndims = _infer_shape(x_shape)[0]
+        if ndims == nd:
+            return x
+        newshape = _op.concatenate(
+            [
+                _expr.const([-1], dtype=_infer_type(x_shape).checked_type.dtype),
+                _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+            ],
+            0,
+        )
+        out = _op.reshape(x, _fold_constant(newshape))
+        return out
+
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         lhs_scale = lhs_tensor.qnn_params["scale"]
         rhs_scale = rhs_tensor.qnn_params["scale"]
@@ -2959,6 +2977,135 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_batch_matmul(self, op):
+        """batch_matmul implementation."""
+        try:
+            from tflite.BatchMatMulOptions import BatchMatMulOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        batch_matmul_options = BatchMatMulOptions()
+        op_options = op.BuiltinOptions()
+        batch_matmul_options.Init(op_options.Bytes, op_options.Pos)
+
+        input_a = self.get_expr(input_tensors[0].tensor_idx)
+        input_b = self.get_expr(input_tensors[1].tensor_idx)
+
+        shape_a = shape_of(input_a)
+        shape_b = shape_of(input_b)
+        rank_a = _infer_shape(shape_a)[0]
+        rank_b = _infer_shape(shape_b)[0]
+
+        if rank_a > 2 or rank_b > 2:
+            # Determine the output batch dimension
+            new_a_shape = shape_a
+            new_b_shape = shape_b
+            if rank_a > rank_b:
+                rank_diff = rank_a - rank_b
+                new_b_shape = _op.concatenate(
+                    [
+                        _expr.const([1] * rank_diff, dtype=_infer_type(shape_b).checked_type.dtype),
+                        shape_b,
+                    ],
+                    0,
+                )
+            elif rank_a < rank_b:
+                rank_diff = rank_b - rank_a
+                new_a_shape = _op.concatenate(
+                    [
+                        _expr.const([1] * rank_diff, dtype=_infer_type(shape_a).checked_type.dtype),
+                        shape_a,
+                    ],
+                    0,
+                )
+
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(new_b_shape, [i], [i + 1]),
+                        _op.strided_slice(new_a_shape, [i], [i + 1]),
+                    )
+                    for i in range(max(rank_a, rank_b) - 2)
+                ],
+                0,
+            )
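+            # Example: shape_a = (2, 1, 4, 5) vs shape_b = (3, 5, 6): shape_b
+            # is padded to (1, 3, 5, 6), and the elementwise maximum over the
+            # leading dims gives out_batch = (2, 3)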
+
+            a_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_a, [rank_a - 2], [rank_a]),
+                    ],
+                    0,
+                )
+            )
+            b_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_b, [rank_b - 2], [rank_b]),
+                    ],
+                    0,
+                )
+            )
+            if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
+                input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
+            if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
+                input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
+
+            input_a = self.flatten_to_nd(input_a, shape_of(input_a), 3)
+            input_b = self.flatten_to_nd(input_b, shape_of(input_b), 3)
+
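+            # relay's nn.batch_matmul treats its second operand as transposed
+            # (transpose_b=True by default), so AdjX maps directly to a
+            # transpose of A, while B is transposed only when AdjY is NOT set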
+            if batch_matmul_options.AdjX():
+                input_a = _op.transpose(input_a, [0, 2, 1])
+            if not batch_matmul_options.AdjY():
+                input_b = _op.transpose(input_b, [0, 2, 1])
+
+            if self.is_quantized(op):
+                output = _qnn.op.batch_matmul(
+                    input_a,
+                    input_b,
+                    relay.const(0, "int32"),
+                    relay.const(0, "int32"),
+                    relay.const(1.0, "float32"),
+                    relay.const(1.0, "float32"),
+                )
+            else:
+                output = _op.nn.batch_matmul(input_a, input_b)
+
+            # Reshape output to original dimensions.
+            output_shape = shape_of(output)
+
+            rank_out = _infer_shape(output_shape)[0]
+
+        final_shape = _op.concatenate(
+            [
+                _op.strided_slice(shape_a, [0], [rank_a - 2]),
+                _op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
+            ],
+            0,
+        )
+
+        reshape = _op.reshape(output, _fold_constant(final_shape))
+        # qnn batch_matmul returns an int32 tensor, so we need to requantize
+        if self.is_quantized(op):
+            return _qnn.op.requantize(
+                reshape,
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                out_dtype="int8",
+            )
+        else:
+            return reshape
+
     def convert_space_to_batch_nd(self, op):
         """space_to_batch_nd implementation."""
         input_tensors = self.get_input_tensors(op)
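
The reshape logic above flattens both operands to rank 3, runs the batch
matmul, then restores the leading batch dims from shape_a. A rough numpy
sketch of the same shape arithmetic (illustrative only, not part of the
patch):

    import numpy as np

    a = np.random.rand(2, 3, 4, 5)
    b = np.random.rand(2, 3, 5, 6)
    a3 = a.reshape(-1, *a.shape[-2:])   # flatten_to_nd(input_a, ..., 3) -> (6, 4, 5)
    b3 = b.reshape(-1, *b.shape[-2:])   # flatten_to_nd(input_b, ..., 3) -> (6, 5, 6)
    out3 = a3 @ b3                      # rank-3 batch matmul -> (6, 4, 6)
    # final_shape = shape_a[:-2] ++ output_shape[-2:]
    final = out3.reshape(*a.shape[:-2], *out3.shape[-2:])
    assert np.allclose(final, a @ b)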
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 42a27bbd26..41eb1f3067 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -61,6 +61,7 @@ from tensorflow.python.ops import image_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import variables
+from tensorflow import raw_ops
 
 try:
     from tensorflow import lite as interpreter_wrapper
@@ -319,6 +320,13 @@ def compare_tflite_with_tvm(
             sess.run(variables.global_variables_initializer())
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+
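+        # Keep BATCH_MATMUL intact for operands of rank <= 4 rather than
+        # letting the TFLite converter unfold it into FULLY_CONNECTED ops,
+        # so the new frontend path is actually exercised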
+        if len(input_tensors) > 1:
+            if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4:
+                converter._experimental_disable_batchmatmul_unfold = True
+            else:
+                converter._experimental_disable_batchmatmul_unfold = False
+
         converter.experimental_new_converter = experimental_new_converter
         if quantized:
             if int_quant_dtype == tf.int16:
@@ -734,24 +742,72 @@ def test_forward_cast():
 #######################################################################
 # Batch Mat Mul
 # ----
-def _test_batch_matmul(a_shape, b_shape, dtype, adjoint_a=False, adjoint_b=False):
+def _test_batch_matmul(
+    a_shape, b_shape, dtype, out_dtype, adjoint_a=False, adjoint_b=False, quantized=False
+):
     with tf.Graph().as_default():
         a = array_ops.placeholder(shape=a_shape, dtype=dtype, name="A")
         b = array_ops.placeholder(shape=b_shape, dtype=dtype, name="B")
-        result = math_ops.matmul(a, b, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
+
+        result = raw_ops.BatchMatMulV3(
+            x=a, y=b, Tout=out_dtype, adj_x=adjoint_a, adj_y=adjoint_b, name="batchmatmul"
+        )
+        input_range = {"A": (-100, 100), "B": (-100, 100)} if quantized else 
None
 
         a_np = np.random.uniform(high=5.0, size=a_shape).astype(dtype)
         b_np = np.random.uniform(high=5.0, size=b_shape).astype(dtype)
-        compare_tflite_with_tvm([a_np, b_np], [a.name, b.name], [a, b], [result])
+        compare_tflite_with_tvm(
+            [a_np, b_np],
+            [a.name, b.name],
+            [a, b],
+            [result],
+            experimental_new_converter=True,
+            quantized=quantized,
+            input_range=input_range,
+        )
 
 
-def test_forward_batch_matmul():
+@pytest.mark.parametrize("config", [("int8", "int32", True), ("float32", "float32", False)])
+def test_forward_batch_matmul(config):
     """BATCH_MAT_MUL"""
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32")
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", True, False)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True)
-    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "float32")
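+    # Each config is (dtype, out_dtype, quantized)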
+    _test_batch_matmul(
+        (3, 5, 4), (3, 4, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
+    )
+    _test_batch_matmul(
+        (3, 5, 4),
+        (3, 4, 5),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=True,
+        adjoint_b=True,
+        quantized=config[2],
+    )
+    _test_batch_matmul(
+        (3, 5, 4),
+        (3, 5, 4),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=True,
+        adjoint_b=False,
+        quantized=config[2],
+    )
+    _test_batch_matmul(
+        (3, 5, 4),
+        (3, 5, 4),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=False,
+        adjoint_b=True,
+        quantized=config[2],
+    )
+    _test_batch_matmul(
+        (3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
+    )
+    # BatchMatMul doesn't support larger than 4D tensors
+    # _test_batch_matmul(
+    #    (2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
+    # )
 
 
 #######################################################################
