This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c0b3949765 [Unity][Frontend] Fix torch addmm op param alpha&beta
(#15497)
c0b3949765 is described below
commit c0b3949765b4e768af618df3579f4c2b13d7e613
Author: HLearning <[email protected]>
AuthorDate: Tue Aug 8 10:59:08 2023 +0800
[Unity][Frontend] Fix torch addmm op param alpha&beta (#15497)
This PR fixes how the torch FX translator handles the `alpha` and `beta` parameters of the `addmm` op.
---
python/tvm/relax/frontend/torch/fx_translator.py | 19 ++++++++++++--
tests/python/relax/test_frontend_from_fx.py | 33 ++++++++++++++++++++++--
2 files changed, 48 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 6ff013fc87..fde31af601 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -464,8 +464,23 @@ class TorchFXImporter:
x = self.env[node.args[0]]
y = self.env[node.args[1]]
z = self.env[node.args[2]]
- matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z,
out_dtype="float32"))
- return self.block_builder.emit(relax.op.add(x, matmul))
+ alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1
+ beta = node.kwargs["beta"] if "beta" in node.kwargs else 1
+
+ res = None
+ if alpha != 0:
+ res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z,
out_dtype="float32"))
+ if alpha != 1:
+ dtype = res.struct_info.dtype
+ res = self.block_builder.emit(relax.op.multiply(res,
relax.const(alpha, dtype)))
+ if beta != 0:
+ dtype = x.struct_info.dtype
+ if beta != 1:
+ bias = self.block_builder.emit(relax.op.multiply(x,
relax.const(beta, dtype)))
+ else:
+ bias = x
+ res = bias if res is None else
self.block_builder.emit(relax.op.add(bias, res))
+ return res
def _baddbmm(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index a7b7770d33..8ac1f718e8 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1836,10 +1836,20 @@ def test_addmm():
([10, 10], "float32"),
]
- class Addmm(Module):
+ class Addmm1(Module):
+ def __init__(self):
+ super().__init__()
+
def forward(self, x1, x2, x3):
return torch.addmm(x1, x2, x3)
+ class Addmm2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x1, x2, x3):
+ return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
@tvm.script.ir_module
class expected1:
@R.function
@@ -1856,7 +1866,26 @@ def test_addmm():
R.output(gv)
return gv
- verify_model(Addmm(), input_info, {}, expected1)
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ x1: R.Tensor((10, 10), dtype="float32"),
+ x2: R.Tensor((10, 10), dtype="float32"),
+ x3: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3,
out_dtype="float32")
+ lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv,
R.const(0.5, "float32"))
+ lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1,
R.const(0.8, "float32"))
+ lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+ gv: R.Tensor((10, 10), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ verify_model(Addmm1(), input_info, {}, expected1)
+ verify_model(Addmm2(), input_info, {}, expected2)
def test_split():