This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new c1eb315  [PYTORCH]Std op without specified dimensions support (#6226)
c1eb315 is described below

commit c1eb31566ac7321809f4b9734df97edf378573f6
Author: shiwenloong <[email protected]>
AuthorDate: Fri Aug 7 08:55:46 2020 +0800

    [PYTORCH]Std op without specified dimensions support (#6226)
---
 python/tvm/relay/frontend/pytorch.py          | 11 ++++++++---
 tests/python/frontend/pytorch/test_forward.py |  5 +++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3dfdb2f..bbc684e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1253,9 +1253,14 @@ def _frobenius_norm():
 def _std():
     def _impl(inputs, input_types):
         data = inputs[0]
-        axis = list(_infer_shape(inputs[1]))
-        keepdims = bool(inputs[3])
-        unbiased = bool(inputs[2])
+        if len(inputs) == 2:
+            axis = None
+            keepdims = False
+            unbiased = bool(inputs[1])
+        else:
+            axis = list(_infer_shape(inputs[1]))
+            keepdims = bool(inputs[3])
+            unbiased = bool(inputs[2])
 
         if unbiased:
             msg = "Currently only supports standard-deviation calculated via 
the biased "\
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e370cd5..3c9dfb1 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1869,12 +1869,17 @@ def test_forward_std():
         def forward(self, *args):
             return args[0].std(dim=(2,3), keepdim=False, unbiased=False)
 
+    class Std6(Module):
+        def forward(self, *args):
+            return args[0].std(unbiased=False)
+
     input_data = torch.rand(input_shape).float()
     verify_model(Std1().float().eval(), input_data=input_data)
     verify_model(Std2().float().eval(), input_data=input_data)
     verify_model(Std3().float().eval(), input_data=input_data)
     verify_model(Std4().float().eval(), input_data=input_data)
     verify_model(Std5().float().eval(), input_data=input_data)
+    verify_model(Std6().float().eval(), input_data=input_data)
 
 
 def test_forward_variance():

Reply via email to