This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1aecc9b8b6 [Unity][Frontend][Onnx] Simplify gemm (#15458)
1aecc9b8b6 is described below

commit 1aecc9b8b622dbd870411c60d96a4843d1b75f4e
Author: Josh Fromm <jwfr...@octoml.ai>
AuthorDate: Tue Aug 1 21:13:02 2023 -0700

    [Unity][Frontend][Onnx] Simplify gemm (#15458)
    
    Add simplification check to onnx gemm importer
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 10 +++++-----
 tests/python/relax/test_frontend_onnx.py        |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 74eb904c4f..9a93f395ec 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -374,18 +374,18 @@ class Gemm(OnnxOpConverter):
 
         # Compute Y = alpha * A X B + beta * C
 
-        if alpha is not None:
-            A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype)))
+        if alpha is not None and alpha != 1.0:
+            A = relax.op.multiply(A, relax.const(alpha, dtype=dtype))
 
         if transA:
             A = relax.op.permute_dims(A, [1, 0])
         if transB:
             B = relax.op.permute_dims(B, [1, 0])
-        Y = bb.normalize(relax.op.matmul(A, B))
+        Y = relax.op.matmul(A, B)
 
         if C is not None:
-            if beta is not None:
-                C = bb.normalize(relax.op.multiply(C, relax.const(beta, dtype=dtype)))
+            if beta is not None and beta != 1.0:
+                C = relax.op.multiply(C, relax.const(beta, dtype=dtype))
             Y = relax.op.add(Y, C)
 
         return Y
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 3467e5bba2..647e72f04a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -333,8 +333,8 @@ def test_gather():
     _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
 
 
-@pytest.mark.parametrize("alpha", [None, 0.25])
-@pytest.mark.parametrize("beta", [None, 0.35])
+@pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
+@pytest.mark.parametrize("beta", [None, 0.35, 1.0])
 @pytest.mark.parametrize("useC", [False, True])
 def test_gemm(alpha, beta, useC):
     if useC:

Reply via email to