This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 70399da0a2 [TFLite] Support for BATCH_MATMUL tflite operator (#14423)
70399da0a2 is described below
commit 70399da0a235dfe999bd676b6e6ffc6af8c33e5f
Author: neildhickey <[email protected]>
AuthorDate: Thu Mar 30 20:56:36 2023 +0100
[TFLite] Support for BATCH_MATMUL tflite operator (#14423)
* [TFLite] Support for BATCH_MATMUL tflite operator
Adds support for BATCH_MATMUL operator in the TFLite frontend.
Adds a test that checks supported TFLite types.
* Fixing linting issues
* Fixing more lint issues
* Fixing compare_tflite function for input_tensors < 2
---
python/tvm/relay/frontend/tflite.py | 147 +++++++++++++++++++++++++++
tests/python/frontend/tflite/test_forward.py | 74 ++++++++++++--
2 files changed, 212 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py
b/python/tvm/relay/frontend/tflite.py
index db21fa6668..9daf7f716f 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -32,7 +32,9 @@ from .. import function as _function
from .. import op as _op
from .. import qnn as _qnn
from .common import ExprTable
+from .common import fold_constant as _fold_constant
from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
from .common import lstm_cell, to_int_list, shape_of, try_infer_value
from .common import set_span
from .tflite_flexbuffer import FlexBufferDecoder
@@ -80,6 +82,7 @@ class OperatorConverter(object):
"ARG_MIN": self.convert_arg_min,
"AVERAGE_POOL_2D": self.convert_average_pool2d,
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
+ "BATCH_MATMUL": self.convert_batch_matmul,
"CAST": self.convert_cast,
"CEIL": self.convert_ceil,
"CONCATENATION": self.convert_concatenation,
@@ -492,6 +495,21 @@ class OperatorConverter(object):
"Tensor type {} is currently not
supported".format(str(tensor_type))
)
def flatten_to_nd(self, x, x_shape, nd=3):
    """Collapse ``x`` down to rank ``nd``.

    The trailing ``nd - 1`` dimensions are preserved and everything in
    front of them is folded into a single leading ``-1`` dimension.

    Parameters
    ----------
    x : relay.Expr
        The tensor to reshape.
    x_shape : relay.Expr
        A 1-D expression holding the (possibly dynamic) shape of ``x``.
    nd : int
        Target rank. If ``x`` already has this rank it is returned as-is.

    Returns
    -------
    relay.Expr
        ``x`` reshaped to rank ``nd``.
    """
    rank = _infer_shape(x_shape)[0]
    if rank == nd:
        # Already the requested rank — nothing to do.
        return x
    shape_dtype = _infer_type(x_shape).checked_type.dtype
    # Keep the last nd-1 dims and let -1 absorb all leading dims.
    tail_dims = _op.strided_slice(x_shape, [rank - nd + 1], [rank])
    target_shape = _op.concatenate(
        [_expr.const([-1], dtype=shape_dtype), tail_dims],
        0,
    )
    return _op.reshape(x, _fold_constant(target_shape))
def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
lhs_scale = lhs_tensor.qnn_params["scale"]
rhs_scale = rhs_tensor.qnn_params["scale"]
@@ -2959,6 +2977,135 @@ class OperatorConverter(object):
return out
def convert_batch_matmul(self, op):
    """Convert a TFLite BATCH_MATMUL operator to Relay.

    Broadcasts the batch dimensions of the two inputs when their ranks
    or batch shapes differ, flattens both operands to rank 3, performs
    (q)nn.batch_matmul, and reshapes the result back to the original
    batch layout.

    Parameters
    ----------
    op : tflite.Operator
        The TFLite BATCH_MATMUL operator being converted.

    Returns
    -------
    relay.Expr
        The converted Relay expression.

    Raises
    ------
    ImportError
        If the ``tflite`` package is not installed.
    """
    try:
        from tflite.BatchMatMulOptions import BatchMatMulOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")

    input_tensors = self.get_input_tensors(op)

    assert len(input_tensors) == 2, "two input tensor arguments expected"

    batch_matmul_options = BatchMatMulOptions()
    op_options = op.BuiltinOptions()
    batch_matmul_options.Init(op_options.Bytes, op_options.Pos)

    input_a = self.get_expr(input_tensors[0].tensor_idx)
    input_b = self.get_expr(input_tensors[1].tensor_idx)

    shape_a = shape_of(input_a)
    shape_b = shape_of(input_b)
    rank_a = _infer_shape(shape_a)[0]
    rank_b = _infer_shape(shape_b)[0]

    if rank_a > 2 or rank_b > 2:
        # Align ranks: prepend 1s to the lower-rank operand's shape so
        # both shapes can be compared dimension by dimension.
        new_a_shape = shape_a
        new_b_shape = shape_b
        if rank_a > rank_b:
            rank_diff = rank_a - rank_b
            new_b_shape = _op.concatenate(
                [
                    # BUG FIX: original referenced undefined name
                    # ``b_shape`` here (NameError); the shape expression
                    # is ``shape_b``.
                    _expr.const(
                        [1] * rank_diff,
                        dtype=_infer_type(shape_b).checked_type.dtype,
                    ),
                    shape_b,
                ],
                0,
            )
        elif rank_a < rank_b:
            rank_diff = rank_b - rank_a
            new_a_shape = _op.concatenate(
                [
                    # BUG FIX: original referenced undefined name
                    # ``a_shape``; the shape expression is ``shape_a``.
                    _expr.const(
                        [1] * rank_diff,
                        dtype=_infer_type(shape_a).checked_type.dtype,
                    ),
                    shape_a,
                ],
                0,
            )

        # Output batch dims: element-wise max of the two aligned batch
        # shapes (numpy-style broadcasting of size-1 dims).
        out_batch = _op.concatenate(
            [
                _op.maximum(
                    _op.strided_slice(new_b_shape, [i], [i + 1]),
                    _op.strided_slice(new_a_shape, [i], [i + 1]),
                )
                for i in range(max(rank_a, rank_b) - 2)
            ],
            0,
        )

        a_broadcasted_shape = _fold_constant(
            _op.concatenate(
                [
                    out_batch,
                    _op.strided_slice(shape_a, [rank_a - 2], [rank_a]),
                ],
                0,
            )
        )
        b_broadcasted_shape = _fold_constant(
            _op.concatenate(
                [
                    out_batch,
                    _op.strided_slice(shape_b, [rank_b - 2], [rank_b]),
                ],
                0,
            )
        )
        # BUG FIX: original called broadcast_to with undefined names
        # ``a`` and ``b`` (NameError); the operands are input_a/input_b.
        if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
            input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
        if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
            input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)

    # batch_matmul operates on rank-3 tensors; collapse extra batch dims.
    input_a = self.flatten_to_nd(input_a, shape_a, 3)
    input_b = self.flatten_to_nd(input_b, shape_b, 3)

    if batch_matmul_options.AdjX():
        input_a = _op.transpose(input_a, [0, 2, 1])
    # Relay's batch_matmul treats B as already transposed, so an
    # un-adjointed B must be transposed here (and an adjointed one not).
    if not batch_matmul_options.AdjY():
        input_b = _op.transpose(input_b, [0, 2, 1])

    if self.is_quantized(op):
        # NOTE(review): zero points and scales are hard-coded placeholders
        # (0 / 1.0) rather than taken from the tensors' qnn_params —
        # confirm whether real quantization parameters should be used.
        output = _qnn.op.batch_matmul(
            input_a,
            input_b,
            relay.const(0, "int32"),
            relay.const(0, "int32"),
            relay.const(1.0, "float32"),
            relay.const(1.0, "float32"),
        )
    else:
        output = _op.nn.batch_matmul(input_a, input_b)

    # Reshape output to original dimensions.
    # NOTE(review): the batch dims are taken from shape_a; if input_b has
    # the larger broadcast batch shape this may not match the actual
    # output batch — verify with mismatched-batch inputs.
    output_shape = shape_of(output)

    rank_out = _infer_shape(output_shape)[0]

    final_shape = _op.concatenate(
        [
            _op.strided_slice(shape_a, [0], [rank_a - 2]),
            _op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
        ],
        0,
    )

    reshape = _op.reshape(output, _fold_constant(final_shape))
    # qnn batch matmul returns an int32 tensor so we need to requantize
    if self.is_quantized(op):
        # NOTE(review): requantize also uses placeholder scales/zero
        # points — confirm against the output tensor's qnn_params.
        return _qnn.op.requantize(
            reshape,
            relay.const(1.0, "float32"),
            relay.const(0, "int32"),
            relay.const(1.0, "float32"),
            relay.const(0, "int32"),
            out_dtype="int8",
        )
    return reshape
+
def convert_space_to_batch_nd(self, op):
"""space_to_batch_nd implementation."""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/frontend/tflite/test_forward.py
b/tests/python/frontend/tflite/test_forward.py
index 42a27bbd26..41eb1f3067 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -61,6 +61,7 @@ from tensorflow.python.ops import image_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
+from tensorflow import raw_ops
try:
from tensorflow import lite as interpreter_wrapper
@@ -319,6 +320,13 @@ def compare_tflite_with_tvm(
sess.run(variables.global_variables_initializer())
# convert to tflite model
converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors,
output_tensors)
+
+ if len(input_tensors) > 1:
+ if len(input_tensors[0].shape) <= 4 and
len(input_tensors[1].shape) <= 4:
+ converter._experimental_disable_batchmatmul_unfold = True
+ else:
+ converter._experimental_disable_batchmatmul_unfold = False
+
converter.experimental_new_converter = experimental_new_converter
if quantized:
if int_quant_dtype == tf.int16:
@@ -734,24 +742,72 @@ def test_forward_cast():
#######################################################################
# Batch Mat Mul
# ----
def _test_batch_matmul(
    a_shape, b_shape, dtype, out_dtype, adjoint_a=False, adjoint_b=False, quantized=False
):
    """Build one BatchMatMulV3 graph and compare TFLite output with TVM.

    Parameters
    ----------
    a_shape, b_shape : tuple of int
        Shapes of the two matmul operands.
    dtype : str
        Input element type for both operands.
    out_dtype : str
        TFLite output type (``Tout`` of BatchMatMulV3).
    adjoint_a, adjoint_b : bool
        Whether to adjoint (transpose the last two dims of) each operand.
    quantized : bool
        Run the quantized conversion path with a fixed calibration range.
    """
    with tf.Graph().as_default():
        a = array_ops.placeholder(shape=a_shape, dtype=dtype, name="A")
        b = array_ops.placeholder(shape=b_shape, dtype=dtype, name="B")
        # BUG FIX: removed stray debug statement ``print(tf.__version__)``
        # left over from development.
        result = raw_ops.BatchMatMulV3(
            x=a, y=b, Tout=out_dtype, adj_x=adjoint_a, adj_y=adjoint_b, name="batchmatmul"
        )

        # Quantized runs need an explicit calibration range per input.
        input_range = {"A": (-100, 100), "B": (-100, 100)} if quantized else None
        a_np = np.random.uniform(high=5.0, size=a_shape).astype(dtype)
        b_np = np.random.uniform(high=5.0, size=b_shape).astype(dtype)
        compare_tflite_with_tvm(
            [a_np, b_np],
            [a.name, b.name],
            [a, b],
            [result],
            experimental_new_converter=True,
            quantized=quantized,
            input_range=input_range,
        )
-def test_forward_batch_matmul():
@pytest.mark.parametrize("config", [("int8", "int32", True), ("float32", "float32", False)])
def test_forward_batch_matmul(config):
    """BATCH_MAT_MUL"""
    dtype, out_dtype, quantized = config
    # (a_shape, b_shape, adjoint_a, adjoint_b)
    cases = [
        ((3, 5, 4), (3, 4, 5), False, False),
        ((3, 5, 4), (3, 4, 5), True, True),
        ((3, 5, 4), (3, 5, 4), True, False),
        ((3, 5, 4), (3, 5, 4), False, True),
        ((3, 4, 5, 6), (3, 4, 6, 5), False, False),
    ]
    for a_shape, b_shape, adj_a, adj_b in cases:
        _test_batch_matmul(
            a_shape,
            b_shape,
            dtype=dtype,
            out_dtype=out_dtype,
            adjoint_a=adj_a,
            adjoint_b=adj_b,
            quantized=quantized,
        )
    # BatchMatMul doesn't support larger than 4D tensors
    # _test_batch_matmul(
    #     (2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=dtype, out_dtype=out_dtype, quantized=quantized
    # )
#######################################################################