This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab75b58117 [Bug][Relay] fix relay frontend pytorch op addmm bug (#15294)
ab75b58117 is described below
commit ab75b58117672cfaa70ab59b85de9a961fd8de14
Author: HLearning <[email protected]>
AuthorDate: Thu Jul 20 22:01:23 2023 +0800
[Bug][Relay] fix relay frontend pytorch op addmm bug (#15294)
* Update pytorch.py
fix relay frontend pytorch op: addmm
calculation formula error.
bug:
out = input + alpha * beta * mat1 @ mat2
fix bug:
out = beta * input + alpha * mat1 @ mat2
* add relay frontend pytorch addmm op test
* fix relay frontend pytorch op addmm
* add relay frontend pytorch addmm op test
---
python/tvm/relay/frontend/pytorch.py | 16 +++++++---------
tests/python/frontend/pytorch/test_forward.py | 12 ++++++++++++
2 files changed, 19 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 5e4d755996..37f32e3c02 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1549,24 +1549,22 @@ class PyTorchOpConverter:
def addmm(self, inputs, input_types):
input_mat = inputs[0]
mat1 = inputs[1]
- data_type = input_types[1]
mat2 = inputs[2]
-
beta = inputs[3]
alpha = inputs[4]
+ data_type = input_types[1]
+
+ transposed_mat2 = _op.transform.transpose(mat2, axes=[1, 0])
+ units = self.infer_shape(transposed_mat2)[0]
+ dense_out = _op.nn.dense(mat1, transposed_mat2, units=units)
if not isinstance(alpha, _expr.Expr) and alpha != 1:
alpha = _create_typed_const(alpha, data_type)
- mat1 *= alpha
+ dense_out *= alpha
if not isinstance(beta, _expr.Expr) and beta != 1:
beta = _create_typed_const(beta, data_type)
- mat2 *= beta
-
- transposed_mat2 = _op.transform.transpose(mat2, axes=[1, 0])
-
- units = self.infer_shape(transposed_mat2)[0]
- dense_out = _op.nn.dense(mat1, transposed_mat2, units=units)
+ input_mat *= beta
return dense_out + input_mat
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 83930d1ea8..5b02d26d5d 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5260,6 +5260,18 @@ def test_weight_norm():
verify_model(linear_wn.eval().float(), input_data_linear)
+@tvm.testing.uses_gpu
+def test_addmm():
+ def test_fn(alpha, beta):
+        return lambda inp, batch1, batch2: torch.addmm(inp, batch1, batch2, beta=beta, alpha=alpha)
+
+ M = torch.randn(3, 5)
+ batch1 = torch.randn(3, 4)
+ batch2 = torch.randn(4, 5)
+
+ verify_model(test_fn(0.4, 0.8), [M, batch1, batch2])
+
+
@tvm.testing.uses_gpu
def test_baddbmm():
def test_fn(alpha, beta):