This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d090b8e  [PYTORCH]Padding support (#5638)
d090b8e is described below

commit d090b8e20e827906fff80c0484ddbe34169a5100
Author: Samuel <siju.sam...@huawei.com>
AuthorDate: Fri May 22 02:14:17 2020 +0530

    [PYTORCH]Padding support (#5638)
---
 python/tvm/relay/frontend/pytorch.py          | 27 +++++++++++++--
 tests/python/frontend/pytorch/test_forward.py | 49 +++++++++++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 28703da..cc7cd48 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1342,10 +1342,31 @@ def _none():
 def _pad():
     def _impl(inputs, input_types):
         data = inputs[0]
-        padding = inputs[1]
-        pad_width = list(zip(padding, padding))
+        if isinstance(inputs[1], list):
+            pad_list = inputs[1]
+        else:
+            pad_list = list(_infer_shape(inputs[1]))
+
+        # initialize paddings based on input len
+        pad_len = len(_infer_shape(data)) * 2
+        paddings = [0] * pad_len
+
+        if len(pad_list) >= 2:
+            paddings[-1] = pad_list[1]
+            paddings[-2] = pad_list[0]
+        if len(pad_list) >= 4:
+            paddings[-3] = pad_list[3]
+            paddings[-4] = pad_list[2]
+        if len(pad_list) >= 6:
+            paddings[-5] = pad_list[5]
+            paddings[-6] = pad_list[4]
+
+        # group into tuple of 2 ints
+        paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]
+
         pad_value = inputs[2]
-        return _op.nn.pad(data, pad_width, pad_value)
+
+        return _op.nn.pad(data, paddings, pad_value)
     return _impl
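
For reference, torch.nn.functional.pad takes a flat pad tuple ordered from the
last dimension backwards, as (before, after) pairs, while relay.nn.pad expects
one (before, after) pair per axis in regular axis order. The sketch below is a
standalone illustration, not part of the commit; it generalizes the converter's
explicit if-chain into a loop to show the same mapping:

    # Standalone illustration (not the committed code): convert a PyTorch-style
    # flat pad tuple into the per-axis (before, after) pairs relay.nn.pad expects.
    def torch_pad_to_pairs(pad_list, rank):
        # PyTorch orders the flat list from the LAST dimension backwards:
        # (last_before, last_after, second_last_before, second_last_after, ...)
        paddings = [0] * (rank * 2)
        for i in range(0, len(pad_list), 2):
            paddings[-(i + 1)] = pad_list[i + 1]  # "after" for this axis
            paddings[-(i + 2)] = pad_list[i]      # "before" for this axis
        # Group into one (before, after) pair per axis, outermost axis first.
        return [tuple(paddings[i:i + 2]) for i in range(0, len(paddings), 2)]

    # pad=(1, 1, 2, 2) on a 4-D NCHW tensor pads W by (1, 1) and H by (2, 2):
    print(torch_pad_to_pairs([1, 1, 2, 2], 4))
    # -> [(0, 0), (0, 0), (2, 2), (1, 1)]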
 
 
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f1543f0..85928bf 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1020,6 +1020,50 @@ def test_adaptive_pool3d():
         verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp)
 
 
+def test_forward_functional_pad():
+    torch.set_grad_enabled(False)
+    pad = (0, 0)
+    class Pad1(Module):
+        def forward(self, *args):
+            return torch.nn.functional.pad(args[0], pad, "constant", 0)
+
+    input_data = torch.rand((3, 3, 4, 2))
+    pad = (1, 1)
+    verify_model(Pad1().float().eval(), input_data=input_data)
+
+    pad = (1, 1, 2, 2)
+    verify_model(Pad1().float().eval(), input_data=input_data)
+
+    pad = (0, 1, 2, 1, 3, 3)
+    verify_model(Pad1().float().eval(), input_data=input_data)
+
+
+def test_forward_zero_pad2d():
+    inp = torch.rand((1, 1, 3, 3))
+    verify_model(torch.nn.ZeroPad2d(2).eval(), inp)
+    verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp)
+
+
+def test_forward_constant_pad1d():
+    inp = torch.rand((1, 2, 4))
+    verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
+
+    inp = torch.rand((1, 2, 3))
+    verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp)
+
+
+def test_forward_constant_pad2d():
+    inp = torch.rand((1, 2, 2, 2))
+    verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp)
+    verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp)
+
+
+def test_forward_constant_pad3d():
+    inp = torch.rand((1, 3, 2, 2, 2))
+    verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp)
+    verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp)
+
+
 def test_forward_reflection_pad2d():
     inp = torch.rand((1, 1, 3, 3))
     verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
@@ -2200,6 +2244,11 @@ if __name__ == "__main__":
     test_upsample()
     test_forward_upsample3d()
     test_to()
+    test_forward_functional_pad()
+    test_forward_zero_pad2d()
+    test_forward_constant_pad1d()
+    test_forward_constant_pad2d()
+    test_forward_constant_pad3d()
     test_forward_reflection_pad2d()
     test_adaptive_pool3d()
     test_conv3d()
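
To try the new padding path end to end, a minimal sketch (assuming the
relay.frontend.from_pytorch entry point as used around this commit, with a
shape list of (name, shape) tuples; the input name "input0" is illustrative)
would be:

    # Minimal sketch, assumptions noted above: trace a module that calls
    # torch.nn.functional.pad and import it into Relay, which now routes
    # through the updated _pad converter.
    import torch
    from tvm import relay

    class PadModel(torch.nn.Module):
        def forward(self, x):
            # Pads the last dim by (1, 1) and the second-to-last dim by (2, 2).
            return torch.nn.functional.pad(x, (1, 1, 2, 2), "constant", 0)

    inp = torch.rand((3, 3, 4, 2))
    scripted = torch.jit.trace(PadModel().eval(), inp)
    mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(inp.shape))])
    print(mod)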
