This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 33232deefb [FRONTEND][TFLITE][BugFix] Fix variable typo in batchmatmul converting func (#15259)
33232deefb is described below

commit 33232deefbfe6e49c291136147743f0c27c8ff54
Author: Sungho Shin <[email protected]>
AuthorDate: Wed Jul 12 05:01:19 2023 +0900

    [FRONTEND][TFLITE][BugFix] Fix variable typo in batchmatmul converting func (#15259)
    
    * TFLite frontend bug fix
    
    * Update tflite.py
    
    * lint
    
    * Add pytest
---
 python/tvm/relay/frontend/tflite.py          | 12 ++++++++----
 tests/python/frontend/tflite/test_forward.py | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 9e88a85e03..dfc7ed27a4 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3006,7 +3006,9 @@ class OperatorConverter(object):
                 rank_diff = rank_a - rank_b
                 new_b_shape = _op.concatenate(
                     [
-                        _expr.const([1] * rank_diff, dtype=_infer_type(b_shape).checked_type.dtype),
+                        _expr.const(
+                            [1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype
+                        ),
                         shape_b,
                     ],
                     0,
@@ -3015,7 +3017,9 @@ class OperatorConverter(object):
                 rank_diff = rank_b - rank_a
                 new_a_shape = _op.concatenate(
                     [
-                        _expr.const([1] * rank_diff, dtype=_infer_type(a_shape).checked_type.dtype),
+                        _expr.const(
+                            [1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype
+                        ),
                         shape_a,
                     ],
                     0,
@@ -3041,9 +3045,9 @@ class OperatorConverter(object):
                 _op.concatenate([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0)
             )
             if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
-                input_a = _op.transform.broadcast_to(a, a_broadcasted_shape)
+                input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
             if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
-                input_b = _op.transform.broadcast_to(b, b_broadcasted_shape)
+                input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
 
             input_a = self.flatten_to_nd(input_a, shape_a, 3)
             input_b = self.flatten_to_nd(input_b, shape_b, 3)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index c65e48b402..4ea82e5b4c 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -794,6 +794,15 @@ def test_forward_batch_matmul(config):
         adjoint_b=False,
         quantized=config[2],
     )
+    _test_batch_matmul(
+        (2, 3, 5, 4),
+        (1, 3, 5, 4),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=True,
+        adjoint_b=False,
+        quantized=config[2],
+    )
     _test_batch_matmul(
         (3, 5, 4),
         (3, 5, 4),
@@ -803,6 +812,15 @@ def test_forward_batch_matmul(config):
         adjoint_b=True,
         quantized=config[2],
     )
+    _test_batch_matmul(
+        (2, 3, 5, 4),
+        (1, 3, 5, 4),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=False,
+        adjoint_b=True,
+        quantized=config[2],
+    )
     _test_batch_matmul(
         (3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
     )

Reply via email to