This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new c1eb315 [PYTORCH]Std op without specified dimensions support (#6226)
c1eb315 is described below
commit c1eb31566ac7321809f4b9734df97edf378573f6
Author: shiwenloong <[email protected]>
AuthorDate: Fri Aug 7 08:55:46 2020 +0800
[PYTORCH]Std op without specified dimensions support (#6226)
---
python/tvm/relay/frontend/pytorch.py | 11 ++++++++---
tests/python/frontend/pytorch/test_forward.py | 5 +++++
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3dfdb2f..bbc684e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1253,9 +1253,14 @@ def _frobenius_norm():
def _std():
def _impl(inputs, input_types):
data = inputs[0]
- axis = list(_infer_shape(inputs[1]))
- keepdims = bool(inputs[3])
- unbiased = bool(inputs[2])
+ if len(inputs) == 2:
+ axis = None
+ keepdims = False
+ unbiased = bool(inputs[1])
+ else:
+ axis = list(_infer_shape(inputs[1]))
+ keepdims = bool(inputs[3])
+ unbiased = bool(inputs[2])
if unbiased:
msg = "Currently only supports standard-deviation calculated via
the biased "\
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e370cd5..3c9dfb1 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1869,12 +1869,17 @@ def test_forward_std():
def forward(self, *args):
return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
+ class Std6(Module):
+ def forward(self, *args):
+ return args[0].std(unbiased=False)
+
input_data = torch.rand(input_shape).float()
verify_model(Std1().float().eval(), input_data=input_data)
verify_model(Std2().float().eval(), input_data=input_data)
verify_model(Std3().float().eval(), input_data=input_data)
verify_model(Std4().float().eval(), input_data=input_data)
verify_model(Std5().float().eval(), input_data=input_data)
+ verify_model(Std6().float().eval(), input_data=input_data)
def test_forward_variance():