This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 33232deefb [FRONTEND][TFLITE][BugFix] Fix variable typo in batchmatmul converting func (#15259)
33232deefb is described below
commit 33232deefbfe6e49c291136147743f0c27c8ff54
Author: Sungho Shin <[email protected]>
AuthorDate: Wed Jul 12 05:01:19 2023 +0900
[FRONTEND][TFLITE][BugFix] Fix variable typo in batchmatmul converting func (#15259)
* TFLite frontend bug fix
* Update tflite.py
* lint
* Add pytest
---
python/tvm/relay/frontend/tflite.py | 12 ++++++++----
tests/python/frontend/tflite/test_forward.py | 18 ++++++++++++++++++
2 files changed, 26 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 9e88a85e03..dfc7ed27a4 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3006,7 +3006,9 @@ class OperatorConverter(object):
rank_diff = rank_a - rank_b
new_b_shape = _op.concatenate(
[
- _expr.const([1] * rank_diff, dtype=_infer_type(b_shape).checked_type.dtype),
+ _expr.const(
+ [1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype
+ ),
shape_b,
],
0,
@@ -3015,7 +3017,9 @@ class OperatorConverter(object):
rank_diff = rank_b - rank_a
new_a_shape = _op.concatenate(
[
- _expr.const([1] * rank_diff, dtype=_infer_type(a_shape).checked_type.dtype),
+ _expr.const(
+ [1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype
+ ),
shape_a,
],
0,
@@ -3041,9 +3045,9 @@ class OperatorConverter(object):
_op.concatenate([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0)
)
if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
- input_a = _op.transform.broadcast_to(a, a_broadcasted_shape)
+ input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
- input_b = _op.transform.broadcast_to(b, b_broadcasted_shape)
+ input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
input_a = self.flatten_to_nd(input_a, shape_a, 3)
input_b = self.flatten_to_nd(input_b, shape_b, 3)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index c65e48b402..4ea82e5b4c 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -794,6 +794,15 @@ def test_forward_batch_matmul(config):
adjoint_b=False,
quantized=config[2],
)
+ _test_batch_matmul(
+ (2, 3, 5, 4),
+ (1, 3, 5, 4),
+ dtype=config[0],
+ out_dtype=config[1],
+ adjoint_a=True,
+ adjoint_b=False,
+ quantized=config[2],
+ )
_test_batch_matmul(
(3, 5, 4),
(3, 5, 4),
@@ -803,6 +812,15 @@ def test_forward_batch_matmul(config):
adjoint_b=True,
quantized=config[2],
)
+ _test_batch_matmul(
+ (2, 3, 5, 4),
+ (1, 3, 5, 4),
+ dtype=config[0],
+ out_dtype=config[1],
+ adjoint_a=False,
+ adjoint_b=True,
+ quantized=config[2],
+ )
_test_batch_matmul(
(3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
)