This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e53cfe138c [Frontend][TFLite] Fix undefined symbols and Relay API 
remnants in TFLite frontend (#18929)
e53cfe138c is described below

commit e53cfe138c8660c473d7008c8ccb950ab673d5aa
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Mar 25 17:10:35 2026 -0400

    [Frontend][TFLite] Fix undefined symbols and Relay API remnants in TFLite 
frontend (#18929)
    
    The TFLite frontend was ported from Relay but contains several undefined
    symbols and Relay-specific APIs that cause runtime errors. This PR cleans
    up these issues so that working code paths are clean and broken paths fail
    with clear `NotImplementedError` instead of `NameError`.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 181 ++++++++-------------
 tests/python/relax/test_frontend_tflite.py         |  28 ++++
 2 files changed, 100 insertions(+), 109 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 5ff0444e0b..0abd700562 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -18,7 +18,10 @@
 # pylint: disable=import-outside-toplevel, use-list-literal
 # pylint: disable=no-value-for-parameter, unused-variable
 # pylint: disable=unexpected-keyword-arg, unused-import, too-many-function-args
-# ruff: noqa: RUF005, F821, F841
+# ruff: noqa: RUF005
+# F821: _qnn and _expr references are in unreachable code paths (guarded by 
NotImplementedError)
+# and will be resolved when quantization and vision op support are added.
+# ruff: noqa: F821
 """Tensorflow lite frontend."""
 
 import functools
@@ -468,7 +471,9 @@ class OperatorConverter:
                     qnn_params = dict()
                     qnn_params["scale"] = relax.const(scale, "float32")
                     qnn_params["zero_point"] = relax.const(zero_point, "int32")
-                    raise NotImplementedError("Quantized operators not 
supported now")
+                    raise NotImplementedError(
+                        "Quantized TFLite models are not yet supported in the 
Relax frontend"
+                    )
             return_list.append(TensorWrapper(tensor_idx, tensor, buffer, 
qnn_params))
         return return_list
 
@@ -530,20 +535,14 @@ class OperatorConverter:
             return "bool"
         raise NotImplementedError(f"Tensor type {tensor_type!s} is currently 
not supported")
 
-    def flatten_to_nd(self, x, x_shape, nd=3):
+    def flatten_to_nd(self, x, nd=3):
         """Flatten input tensor to nd rank"""
-        ndims = self._infer_shape(x_shape)[0]
+        shape = x.struct_info.shape
+        ndims = len(shape)
         if ndims == nd:
             return x
-        newshape = relax.op.concat(
-            [
-                relax.const([-1], 
dtype=self._infer_type(x_shape).checked_type.dtype),
-                relax.op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
-            ],
-            0,
-        )
-        out = relax.op.reshape(x, self._fold_constant(newshape))
-        return out
+        new_shape = [-1] + [int(shape[i]) for i in range(ndims - nd + 1, 
ndims)]
+        return relax.op.reshape(x, new_shape)
 
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         lhs_scale = lhs_tensor.qnn_params["scale"]
@@ -709,7 +708,7 @@ class OperatorConverter:
 
         # ResizeNearestNeighborOptions was added in tflite v1.13
         tflite_ver = 1120
-        if "ResizeNearestNeighborOptions" in dir(tflite.BuiltinOptions):
+        if hasattr(BuiltinOptions, "ResizeNearestNeighborOptions"):
             tflite_ver = 1130
 
         input_tensors = self.get_input_tensors(op)
@@ -947,8 +946,7 @@ class OperatorConverter:
         shape_options = ShapeOptions()
         shape_options.Init(op_options.Bytes, op_options.Pos)
 
-        out_type = self.get_tensor_type_str(shape_options.OutType())
-        out = shape_of(self.get_tensor_expr(input_tensors[0]), dtype=out_type)
+        out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))
 
         return out
 
@@ -1428,6 +1426,7 @@ class OperatorConverter:
 
         from tflite.BuiltinOptions import BuiltinOptions
         from tflite.GatherOptions import GatherOptions
+        from tflite.TensorType import TensorType
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 2, "input tensors length should be 2"
@@ -2804,6 +2803,11 @@ class OperatorConverter:
 
         assert len(input_tensors) == 2, "two input tensor arguments expected"
 
+        if self.is_quantized(op):
+            raise NotImplementedError(
+                "Quantized BATCH_MATMUL is not yet supported in the Relax 
frontend"
+            )
+
         batch_matmul_options = BatchMatMulOptions()
         op_options = op.BuiltinOptions()
         batch_matmul_options.Init(op_options.Bytes, op_options.Pos)
@@ -2811,108 +2815,54 @@ class OperatorConverter:
         input_a = self.get_expr(input_tensors[0].tensor_idx)
         input_b = self.get_expr(input_tensors[1].tensor_idx)
 
-        shape_a = shape_of(input_a)
-        shape_b = shape_of(input_b)
-        rank_a = self._infer_shape(shape_a)[0]
-        rank_b = self._infer_shape(shape_b)[0]
+        shape_a = list(input_a.struct_info.shape)
+        shape_b = list(input_b.struct_info.shape)
+        rank_a = len(shape_a)
+        rank_b = len(shape_b)
 
         if rank_a > 2 or rank_b > 2:
-            # Determine the output batch dimension
-            new_a_shape = shape_a
-            new_b_shape = shape_b
-            if rank_a > rank_b:
-                rank_diff = rank_a - rank_b
-                new_b_shape = relax.op.concat(
-                    [
-                        relax.const(
-                            [1] * rank_diff, 
dtype=self._infer_type(new_b_shape).checked_type.dtype
-                        ),
-                        shape_b,
-                    ],
-                    0,
-                )
-            elif rank_a < rank_b:
-                rank_diff = rank_b - rank_a
-                new_a_shape = relax.op.concat(
-                    [
-                        relax.const(
-                            [1] * rank_diff, 
dtype=self._infer_type(new_a_shape).checked_type.dtype
-                        ),
-                        shape_a,
-                    ],
-                    0,
-                )
-            else:
-                pass
+            # Broadcast batch dimensions
+            new_a_shape = [1] * max(0, rank_b - rank_a) + [int(s) for s in 
shape_a]
+            new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in 
shape_b]
+            max_rank = max(rank_a, rank_b)
 
-            out_batch = relax.op.concat(
-                [
-                    relax.op.maximum(
-                        relax.op.strided_slice(new_b_shape, [i], [i + 1]),
-                        relax.op.strided_slice(new_a_shape, [i], [i + 1]),
-                    )
-                    for i in range(max(rank_a, rank_b) - 2)
-                ],
-                0,
-            )
+            batch_shape = [
+                max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 
2)
+            ]
 
-            a_broadcasted_shape = _fold_constant(
-                _op.concat([out_batch, _op.strided_slice(shape_a, [rank_a - 
2], [rank_a])], 0)
-            )
-            b_broadcasted_shape = _fold_constant(
-                _op.concat([out_batch, _op.strided_slice(shape_b, [rank_b - 
2], [rank_b])], 0)
-            )
-            if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
-                input_a = relax.op.transform.broadcast_to(input_a, 
a_broadcasted_shape)
-            if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
-                input_b = relax.op.transform.broadcast_to(input_b, 
b_broadcasted_shape)
+            a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
+            b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
 
-            input_a = self.flatten_to_nd(input_a, shape_a, 3)
-            input_b = self.flatten_to_nd(input_b, shape_b, 3)
+            if [int(s) for s in shape_a] != a_broadcast:
+                input_a = relax.op.broadcast_to(input_a, a_broadcast)
+            if [int(s) for s in shape_b] != b_broadcast:
+                input_b = relax.op.broadcast_to(input_b, b_broadcast)
 
-            if batch_matmul_options.AdjX():
+            input_a = self.flatten_to_nd(input_a, 3)
+            input_b = self.flatten_to_nd(input_b, 3)
+
+            adj_x = batch_matmul_options.AdjX()
+            adj_y = batch_matmul_options.AdjY()
+
+            if adj_x:
                 input_a = relax.op.permute_dims(input_a, [0, 2, 1])
-            if not batch_matmul_options.AdjY():
+            if adj_y:
                 input_b = relax.op.permute_dims(input_b, [0, 2, 1])
 
-            if self.is_quantized(op):
-                output = _qnn.op.batch_matmul(
-                    input_a,
-                    input_b,
-                    relax.const(0, "int32"),
-                    relax.const(0, "int32"),
-                    relax.const(1.0, "float32"),
-                    relax.const(1.0, "float32"),
-                )
-            else:
-                output = relax.op.nn.batch_matmul(input_a, input_b)
+            output = relax.op.matmul(input_a, input_b)
 
-            # Reshape output to original dimensions.
-            output_shape = shape_of(output)
+            # Compute output matmul dims from original shapes
+            m_dim = int(shape_a[-1]) if adj_x else int(shape_a[-2])
+            n_dim = int(shape_b[-2]) if adj_y else int(shape_b[-1])
+            final_shape = [int(s) for s in shape_a[: rank_a - 2]] + [m_dim, 
n_dim]
+            return relax.op.reshape(output, final_shape)
 
-            rank_out = self._infer_shape(output_shape)[0]
-
-        final_shape = relax.op.concat(
-            [
-                relax.op.strided_slice(shape_a, [0], [rank_a - 2]),
-                relax.op.strided_slice(output_shape, [rank_out - 2], 
[rank_out]),
-            ],
-            0,
-        )
-
-        reshape = relax.op.reshape(output, self._fold_constant(final_shape))
-        # qnn batch matmul returns a int32 tensor so we need to requantize
-        if self.is_quantized(op):
-            return _qnn.op.requantize(
-                reshape,
-                relax.const(1.0, "float32"),
-                relax.const(0, "int32"),
-                relax.const(1.0, "float32"),
-                relax.const(0, "int32"),
-                out_dtype="int8",
-            )
-        else:
-            return reshape
+        # rank <= 2: use matmul directly
+        if batch_matmul_options.AdjX():
+            input_a = relax.op.permute_dims(input_a)
+        if batch_matmul_options.AdjY():
+            input_b = relax.op.permute_dims(input_b)
+        return relax.op.matmul(input_a, input_b)
 
     def convert_space_to_batch_nd(self, op):
         """space_to_batch_nd implementation."""
@@ -2974,6 +2924,7 @@ class OperatorConverter:
 
     def convert_sparse_to_dense(self, op):
         """Convert TFLite SPARSE_TO_DENSE"""
+        from tflite.TensorType import TensorType
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 4, "input tensors length should be 4"
@@ -3029,6 +2980,7 @@ class OperatorConverter:
 
         from tflite.BuiltinOptions import BuiltinOptions
         from tflite.Padding import Padding
+        from tflite.TensorType import TensorType
         from tflite.TransposeConvOptions import TransposeConvOptions
 
         input_tensors = self.get_input_tensors(op)
@@ -3226,6 +3178,7 @@ class OperatorConverter:
 
     def convert_dequantize(self, op):
         """Convert TFLite Dequantize"""
+        from tflite.TensorType import TensorType
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
@@ -3251,6 +3204,11 @@ class OperatorConverter:
 
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
+        raise NotImplementedError(
+            "DETECTION_POSTPROCESS requires vision ops 
(multibox_transform_loc, "
+            "non_max_suppression, get_valid_counts) not yet available in 
Relax. "
+            "See https://github.com/apache/tvm/issues/XXXX"
+        )
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
         custom_options = FlexBufferDecoder(flexbuffer).decode()
 
@@ -3381,6 +3339,11 @@ class OperatorConverter:
     def convert_nms_v5(self, op):
         """Convert TFLite NonMaxSuppressionV5"""
         # 
https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
+        raise NotImplementedError(
+            "NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, "
+            "non_max_suppression) not yet available in Relax. "
+            "See https://github.com/apache/tvm/issues/XXXX"
+        )
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 6, "input tensor length should be 6"
@@ -3843,7 +3806,7 @@ def prepare_dense_matrix_from_sparse(sparse_tensor, 
sparse_tensor_value, sparse_
 
 def get_scalar_from_constant(expr):
     """Returns scalar value from Relax constant scalar."""
-    assert isinstance(expr, _expr.Constant) and not expr.data.shape, (
+    assert isinstance(expr, relax.Constant) and not expr.data.shape, (
         "Expr is not a constant scalar."
     )
     value = expr.data.numpy()
@@ -4091,7 +4054,7 @@ def from_tflite(
 
     with bb.function("main"):
         input_list = []
-        with bb.dataflow() as df:  # pylint: disable=invalid-name, 
unused-variable
+        with bb.dataflow() as df:  # noqa: F841  # pylint: 
disable=invalid-name, unused-variable
             exp_tab = ExprTable()
             for model_input in model_inputs:
                 model_input_name = get_tensor_name(subgraph, model_input)
diff --git a/tests/python/relax/test_frontend_tflite.py 
b/tests/python/relax/test_frontend_tflite.py
index 9d3d6a9aeb..e7d81cf5fe 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -825,5 +825,33 @@ def test_networks(net, shape):
     verify(concrete_func)
 
 
+def test_batch_matmul():
+    class BatchMatMul(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
+                tf.TensorSpec(shape=(2, 4, 5), dtype=tf.float32),
+            ]
+        )
+        def func(self, x, y):
+            return tf.matmul(x, y)
+
+    verify(BatchMatMul)
+
+
+def test_batch_matmul_adj():
+    class BatchMatMulAdj(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 4, 3), dtype=tf.float32),
+                tf.TensorSpec(shape=(2, 5, 4), dtype=tf.float32),
+            ]
+        )
+        def func(self, x, y):
+            return tf.matmul(x, y, transpose_a=True, transpose_b=True)
+
+    verify(BatchMatMulAdj)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])

Reply via email to