This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c0b3949765 [Unity][Frontend] Fix torch addmm op param alpha&beta (#15497)
c0b3949765 is described below

commit c0b3949765b4e768af618df3579f4c2b13d7e613
Author: HLearning <[email protected]>
AuthorDate: Tue Aug 8 10:59:08 2023 +0800

    [Unity][Frontend] Fix torch addmm op param alpha&beta (#15497)
    
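    This PR fixes how the torch FX translator handles the `alpha` and `beta`
    parameters of the `addmm` op. Previously both were ignored and the
    translator always emitted `x + matmul(y, z)`.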
---
 python/tvm/relax/frontend/torch/fx_translator.py | 19 ++++++++++++--
 tests/python/relax/test_frontend_from_fx.py      | 33 ++++++++++++++++++++++--
 2 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 6ff013fc87..fde31af601 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -464,8 +464,23 @@ class TorchFXImporter:
         x = self.env[node.args[0]]
         y = self.env[node.args[1]]
         z = self.env[node.args[2]]
-        matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, 
out_dtype="float32"))
-        return self.block_builder.emit(relax.op.add(x, matmul))
+        alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1
+        beta = node.kwargs["beta"] if "beta" in node.kwargs else 1
+
+        res = None
+        if alpha != 0:
+            res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, 
out_dtype="float32"))
+            if alpha != 1:
+                dtype = res.struct_info.dtype
+                res = self.block_builder.emit(relax.op.multiply(res, 
relax.const(alpha, dtype)))
+        if beta != 0:
+            dtype = x.struct_info.dtype
+            if beta != 1:
+                bias = self.block_builder.emit(relax.op.multiply(x, 
relax.const(beta, dtype)))
+            else:
+                bias = x
+            res = bias if res is None else 
self.block_builder.emit(relax.op.add(bias, res))
+        return res
 
     def _baddbmm(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index a7b7770d33..8ac1f718e8 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1836,10 +1836,20 @@ def test_addmm():
         ([10, 10], "float32"),
     ]
 
-    class Addmm(Module):
+    class Addmm1(Module):
+        def __init__(self):
+            super().__init__()
+
         def forward(self, x1, x2, x3):
             return torch.addmm(x1, x2, x3)
 
+    class Addmm2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -1856,7 +1866,26 @@ def test_addmm():
                 R.output(gv)
             return gv
 
-    verify_model(Addmm(), input_info, {}, expected1)
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, 
out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, 
R.const(0.5, "float32"))
+                lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, 
R.const(0.8, "float32"))
+                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+                gv: R.Tensor((10, 10), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    verify_model(Addmm1(), input_info, {}, expected1)
+    verify_model(Addmm2(), input_info, {}, expected2)
 
 
 def test_split():

Reply via email to