This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 11815b8 Don't multiply by constant 1 uselessly in dense (#5911)
11815b8 is described below
commit 11815b8d8fd9255e2d5ea1fc9ada98222228d462
Author: Thomas Viehmann <[email protected]>
AuthorDate: Wed Jun 24 13:49:43 2020 +0200
Don't multiply by constant 1 uselessly in dense (#5911)
---
python/tvm/relay/frontend/pytorch.py | 4 ++--
tests/python/frontend/pytorch/test_forward.py | 19 +++++++++++++++++++
2 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9237303..84b0907 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -995,11 +995,11 @@ def _dense():
beta = inputs[3]
alpha = inputs[4]
- if not isinstance(alpha, _expr.Expr):
+ if not isinstance(alpha, _expr.Expr) and alpha != 1:
alpha = _create_typed_const(alpha, data_type)
data *= alpha
- if not isinstance(beta, _expr.Expr):
+ if not isinstance(beta, _expr.Expr) and beta != 1:
beta = _create_typed_const(beta, data_type)
weight *= beta
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 12d1260..0694fa5 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -33,6 +33,18 @@ from tvm.relay.testing.config import ctx_list
sys.setrecursionlimit(10000)
+def list_ops(expr):
+    """Collect the operators reached while visiting a Relay expression.
+
+    Returns a list of the distinct nodes handed to ``visit_op`` during the
+    traversal of ``expr``, in first-seen order.
+    """
+    class OpLister(tvm.relay.ExprVisitor):
+        """Visitor that records each distinct operator node once."""
+
+        def visit_op(self, expr):
+            # Remember seen ops in a set so a repeated visit is not recorded
+            # twice; the list keeps first-seen order for the caller.
+            if expr not in self.node_set:
+                self.node_set.add(expr)
+                self.node_list.append(expr)
+            return super().visit_op(expr)
+
+        def list_nodes(self, expr):
+            """Reset the collected state, visit ``expr``, return the ops found."""
+            self.node_set = set()
+            self.node_list = []
+            self.visit(expr)
+            return self.node_list
+
+    return OpLister().list_nodes(expr)
def assert_shapes_match(tru, est):
if tru.shape != est.shape:
@@ -1047,6 +1059,13 @@ def test_forward_dense():
verify_model(Dense1().float().eval(), input_data=input_data)
verify_model(Dense2().float().eval(), input_data=input_data)
+ trace = torch.jit.trace(Dense1(), [input_data])
+ mod, params = relay.frontend.from_pytorch(
+ trace,
+ [('input', input_shape)],
+ )
+ assert not any([op.name == "multiply" for op in list_ops(mod['main'])])
+
def test_forward_dropout():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]