wyc-ruiker commented on a change in pull request #8443:
URL: https://github.com/apache/tvm/pull/8443#discussion_r668389145



##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -3912,6 +3912,109 @@ def forward(self, x):
     verify_model(Flip(axis=-1), input_data=input)
 
 
+@tvm.testing.uses_gpu
+def test_forward_im2col():
+    torch.set_grad_enabled(False)
+
+    class Im2col3x3(Module):
+        def __init__(self):
+            super(Im2col3x3, self).__init__()
+
+        def forward(self, x):
+            # ***********************************************************************************
+            #
+            # !!! DO NOT USE !!!
+            # F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1)
+            # because it breaks TVM's handling of the "if" conditional expression in TorchScript mode
+            #
+            # ***********************************************************************************
+
+            return torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1))
+
+    class Im2col5x5(Module):
+        def __init__(self):
+            super(Im2col5x5, self).__init__()
+
+        def forward(self, x):
+            # ***********************************************************************************
+            #
+            # !!! DO NOT USE !!!
+            # F.unfold(x, kernel_size=5, dilation=1, padding=1, stride=2)
+            # because it breaks TVM's handling of the "if" conditional expression in TorchScript mode
+            #
+            # ***********************************************************************************
+
+            return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2))
+
+    model = Im2col3x3()
+    input = torch.randn(2, 3, 32, 32)
+    verify_model(model, input_data=input)
+
+    verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets())
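For reference when reading the two cases above: `F.unfold` on a 4-D input calls `torch._C._nn.im2col(input, kernel_size, dilation, padding, stride)` and returns a tensor of shape `(N, C * kh * kw, L)`, where `L` is the number of sliding-window positions, computed per spatial axis as `(size + 2*padding - dilation*(k - 1) - 1) // stride + 1`. The sketch below computes that shape in plain Python. The helper name `im2col_output_shape` is hypothetical, and it assumes square kernels with the same dilation, padding and stride on both axes, as in both test modules.

```python
def im2col_output_shape(input_shape, kernel_size, dilation, padding, stride):
    """Expected (N, C*k*k, L) shape of im2col / F.unfold for an NCHW input.

    Hypothetical helper: assumes a square kernel and identical dilation,
    padding and stride on the H and W axes.
    """
    n, c, h, w = input_shape

    def out_len(size):
        # Number of window positions along one spatial axis.
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    blocks = out_len(h) * out_len(w)
    return (n, c * kernel_size * kernel_size, blocks)


# Im2col3x3: kernel 3, dilation 1, padding 1, stride 1
print(im2col_output_shape((2, 3, 32, 32), 3, 1, 1, 1))  # (2, 27, 1024)
# Im2col5x5: kernel 5, dilation 1, padding 1, stride 2
print(im2col_output_shape((2, 3, 32, 32), 5, 1, 1, 2))  # (2, 75, 225)
```

The stride-2 case is the interesting one: it shrinks each spatial axis from 32 to 15 positions, so it exercises a different output length than the stride-1 case.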

Review comment:
       Hi @delldu, the other parts LGTM. But I think we can rewrite the unit test more concisely and cover more cases by passing some parameters. Here is a simple example:
   ```suggestion
       class Im2col(Module):
           def __init__(self, kernel_size, dilation, padding, stride):
               super(Im2col, self).__init__()
               self.kernel_size = kernel_size
               self.dilation = dilation
               self.padding = padding
               self.stride = stride
   
           def forward(self, x):
               # ***********************************************************************************
               #
               # !!! DO NOT USE !!!
               # F.unfold(x, kernel_size, dilation, padding, stride)
               # because it breaks TVM's handling of the "if" conditional expression in TorchScript mode
               #
               # ***********************************************************************************
   
               # Pass every constructor parameter through as a square (h, w) pair.
               return torch._C._nn.im2col(
                   x,
                   (self.kernel_size, self.kernel_size),
                   (self.dilation, self.dilation),
                   (self.padding, self.padding),
                   (self.stride, self.stride),
               )
   
       input = torch.randn(2, 3, 32, 32)
   
       verify_model(Im2col(3, 1, 1, 1), input_data=input)
       verify_model(Im2col(5, 1, 1, 1), input_data=input)
       verify_model(Im2col(3, 1, 2, 1), input_data=input)
       verify_model(Im2col(5, 1, 2, 1), input_data=input)
       verify_script_model(Im2col(3, 1, 1, 1).eval(), [(2, 3, 32, 32)], _get_default_vm_targets())
       verify_script_model(Im2col(5, 1, 1, 1).eval(), [(2, 3, 32, 32)], _get_default_vm_targets())
       verify_script_model(Im2col(3, 1, 2, 1).eval(), [(2, 3, 32, 32)], _get_default_vm_targets())
       verify_script_model(Im2col(5, 1, 2, 1).eval(), [(2, 3, 32, 32)], _get_default_vm_targets())
   ```
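To see what each parameterized case is expected to produce, here is a naive, pure-Python im2col for a single image stored as nested lists `[C][H][W]`, taking the same `kernel_size`, `dilation`, `padding` and `stride` arguments as the `Im2col` module. It is a sketch under stated assumptions: square kernels, implicit zero padding, and the `(C*k*k, L)` layout that `F.unfold` uses for one batch element (rows ordered by channel, then kernel row, then kernel column; columns ordered row-major over output positions). `im2col_reference` is a hypothetical helper, not part of TVM or PyTorch.

```python
from itertools import product


def im2col_reference(image, kernel_size, dilation, padding, stride):
    """Naive im2col for one image given as nested lists [C][H][W].

    Hypothetical helper. Returns C*k*k rows, each holding one value per
    sliding-window position; out-of-bounds reads act as zero padding.
    """
    channels = len(image)
    height, width = len(image[0]), len(image[0][0])
    k = kernel_size
    out_h = (height + 2 * padding - dilation * (k - 1) - 1) // stride + 1
    out_w = (width + 2 * padding - dilation * (k - 1) - 1) // stride + 1

    columns = []
    # Row order: channel first, then kernel row, then kernel column.
    for c, ki, kj in product(range(channels), range(k), range(k)):
        row = []
        # Column order: output positions in row-major order.
        for oi, oj in product(range(out_h), range(out_w)):
            i = oi * stride - padding + ki * dilation
            j = oj * stride - padding + kj * dilation
            inside = 0 <= i < height and 0 <= j < width
            row.append(image[c][i][j] if inside else 0)
        columns.append(row)
    return columns


image = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
print(im2col_reference(image, 2, 1, 0, 1))
# [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]]
```

If PyTorch is available, the result for an input `x` of shape `(1, C, H, W)` should match `F.unfold(x, k, dilation, padding, stride)[0].tolist()`, which makes it a handy cross-check when adding new parameter combinations to the test.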




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
