This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1aecc9b8b6 [Unity][Frontend][Onnx] Simplify gemm (#15458)
1aecc9b8b6 is described below
commit 1aecc9b8b622dbd870411c60d96a4843d1b75f4e
Author: Josh Fromm <[email protected]>
AuthorDate: Tue Aug 1 21:13:02 2023 -0700
[Unity][Frontend][Onnx] Simplify gemm (#15458)
Add simplification check to onnx gemm importer: skip the alpha/beta multiply when the scale is 1.0 and drop the intermediate bb.normalize calls.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 10 +++++-----
tests/python/relax/test_frontend_onnx.py | 4 ++--
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 74eb904c4f..9a93f395ec 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -374,18 +374,18 @@ class Gemm(OnnxOpConverter):
# Compute Y = alpha * A X B + beta * C
- if alpha is not None:
- A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype)))
+ if alpha is not None and alpha != 1.0:
+ A = relax.op.multiply(A, relax.const(alpha, dtype=dtype))
if transA:
A = relax.op.permute_dims(A, [1, 0])
if transB:
B = relax.op.permute_dims(B, [1, 0])
- Y = bb.normalize(relax.op.matmul(A, B))
+ Y = relax.op.matmul(A, B)
if C is not None:
- if beta is not None:
- C = bb.normalize(relax.op.multiply(C, relax.const(beta, dtype=dtype)))
+ if beta is not None and beta != 1.0:
+ C = relax.op.multiply(C, relax.const(beta, dtype=dtype))
Y = relax.op.add(Y, C)
return Y
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 3467e5bba2..647e72f04a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -333,8 +333,8 @@ def test_gather():
_verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
-@pytest.mark.parametrize("alpha", [None, 0.25])
-@pytest.mark.parametrize("beta", [None, 0.35])
+@pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
+@pytest.mark.parametrize("beta", [None, 0.35, 1.0])
@pytest.mark.parametrize("useC", [False, True])
def test_gemm(alpha, beta, useC):
if useC: