delldu commented on a change in pull request #8443:
URL: https://github.com/apache/tvm/pull/8443#discussion_r668549131



##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -3912,6 +3912,109 @@ def forward(self, x):
     verify_model(Flip(axis=-1), input_data=input)
 
 
+@tvm.testing.uses_gpu
+def test_forward_im2col():
+    torch.set_grad_enabled(False)
+
+    class Im2col3x3(Module):
+        def __init__(self):
+            super(Im2col3x3, self).__init__()
+
+        def forward(self, x):
+            # ***********************************************************************************
+            #
+            # !!! DO NOT USE !!!
+            # F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1)
+            # because its "if" conditional expression breaks TVM in TorchScript mode
+            #
+            # ***********************************************************************************
+
+            return torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1))
+
+    class Im2col5x5(Module):
+        def __init__(self):
+            super(Im2col5x5, self).__init__()
+
+        def forward(self, x):
+            # ***********************************************************************************
+            #
+            # !!! DO NOT USE !!!
+            # F.unfold(x, kernel_size=5, dilation=1, padding=1, stride=2)
+            # because its "if" conditional expression breaks TVM in TorchScript mode
+            #
+            # ***********************************************************************************
+
+            return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2))
+
+    model = Im2col3x3()
+    input = torch.randn(2, 3, 32, 32)
+    verify_model(model, input_data=input)
+
+    verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets())

Review comment:
   @wyc-ruiker, thanks.
   
   Your suggestion works for trace mode, but not for script mode.
   
   Let me explain briefly, using Im2col(3, 1, 1, 1) (kernel_size=3, dilation=1, padding=1, stride=1) as an example.
   
   **1. Trace Mode -- Static Graph**
   `print(traced_model.graph)`
   ```
   graph(%self : __torch__.test_forward.___torch_mangle_6.Im2col,
          %x : Float(2, 3, 32, 32)):
      %3 : int = prim::Constant[value=3]()
      %4 : int = prim::Constant[value=3]()
      %5 : int[] = prim::ListConstruct(%3, %4)
      %6 : int = prim::Constant[value=1]()
      %7 : int = prim::Constant[value=1]()
      %8 : int[] = prim::ListConstruct(%6, %7)
      %9 : int = prim::Constant[value=1]()
      %10 : int = prim::Constant[value=1]()
      %11 : int[] = prim::ListConstruct(%9, %10)
      %12 : int = prim::Constant[value=1]()
      %13 : int = prim::Constant[value=1]()
      %14 : int[] = prim::ListConstruct(%12, %13)
      %15 : Float(2, 27, 1024) = aten::im2col(%x, %5, %8, %11, %14)
      return (%15)
   ```
   `print(traced_model.code)`
   ```
   def forward(self,
        x: Tensor) -> Tensor:
      _0 = torch.im2col(x, [3, 3], [1, 1], [1, 1], [1, 1])
      return _0
   ```
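As a sanity check on the `Float(2, 27, 1024)` annotation in the traced graph, here is a minimal pure-Python sketch that computes im2col's output shape from the sliding-window formula documented for `torch.nn.Unfold`. The helper name `im2col_output_shape` is hypothetical; it is not a TVM or PyTorch API.

```python
from typing import Sequence, Tuple


def im2col_output_shape(
    input_shape: Sequence[int],
    kernel_size: Tuple[int, int],
    dilation: Tuple[int, int],
    padding: Tuple[int, int],
    stride: Tuple[int, int],
) -> Tuple[int, int, int]:
    """Return (N, C * kh * kw, L) for a 4-D NCHW input.

    L is the number of sliding-window positions: for each spatial dim,
    floor((size + 2*pad - dil*(k - 1) - 1) / stride) + 1, multiplied together.
    """
    n, c, h, w = input_shape
    blocks = 1
    for size, k, d, p, s in zip((h, w), kernel_size, dilation, padding, stride):
        # Number of window positions along this spatial dimension.
        blocks *= (size + 2 * p - d * (k - 1) - 1) // s + 1
    # Each column holds one C x kh x kw patch, flattened.
    return n, c * kernel_size[0] * kernel_size[1], blocks
```

For the 3x3 case, `im2col_output_shape((2, 3, 32, 32), (3, 3), (1, 1), (1, 1), (1, 1))` gives `(2, 27, 1024)`, matching the traced graph; the 5x5, stride-2 configuration used by `Im2col5x5` gives `(2, 75, 225)`.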
   
   **2. Script Mode -- Dynamic Graph**
   `print(script_model.graph)`
   ```
   graph(%self : __torch__.test_forward.Im2col,
         %x.1 : Tensor):
     %3 : int = prim::GetAttr[name="kernel_size"](%self)
     %4 : int = prim::GetAttr[name="kernel_size"](%self)
     %6 : int = prim::GetAttr[name="dilation"](%self)
     %7 : int = prim::GetAttr[name="dilation"](%self)
     %9 : int = prim::GetAttr[name="padding"](%self)
     %10 : int = prim::GetAttr[name="padding"](%self)
     %12 : int = prim::GetAttr[name="stride"](%self)
     %13 : int = prim::GetAttr[name="stride"](%self)
     %15 : int[] = prim::ListConstruct(%3, %4)
     %16 : int[] = prim::ListConstruct(%6, %7)
     %17 : int[] = prim::ListConstruct(%9, %10)
     %18 : int[] = prim::ListConstruct(%12, %13)
     %19 : Tensor = aten::im2col(%x.1, %15, %16, %17, %18)
     return (%19)
   ```
   `print(script_model.code)`
   ```
   def forward(self,
       x: Tensor) -> Tensor:
     _0 = self.kernel_size
     _1 = self.kernel_size
     _2 = self.dilation
     _3 = self.dilation
     _4 = self.padding
     _5 = self.padding
     _6 = self.stride
     _7 = self.stride
     _8 = torch.im2col(x, [_0, _1], [_2, _3], [_4, _5], [_6, _7])
     return _8
   ```
   
   **3. Root Cause**
   The TVM PyTorch frontend cannot get any information about **self** (the `prim::GetAttr` nodes above), so the conversion fails.
    
   **4. Options**
     a) Redesign `frontend.from_pytorch` to pre-process the model; this would be better in the long run.
     b) Keep it as is. "Transform parameters via forward" is the better solution for now.
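To make the root cause and option a) concrete, here is a toy sketch in pure Python. The node tuples and the names `SCRIPT_NODES`, `toy_convert`, and `fold_attrs` are hypothetical and do not mirror TVM's real `from_pytorch` internals. It only models the situation described above: a converter that sees the graph but not the module instance cannot resolve `prim::GetAttr` on `self`, while a pre-processing pass that folds those attributes into constants removes the problem.

```python
from typing import Dict, List, Tuple, Union

# Hypothetical toy IR: each node is (output_name, op, argument).
# For "prim::Constant" the argument is the literal value;
# for "prim::GetAttr" it is the attribute name read from `self`.
Node = Tuple[str, str, Union[int, str]]

# The integer nodes of the script-mode graph shown above.
SCRIPT_NODES: List[Node] = [
    ("%3", "prim::GetAttr", "kernel_size"),
    ("%4", "prim::GetAttr", "kernel_size"),
    ("%6", "prim::GetAttr", "dilation"),
    ("%7", "prim::GetAttr", "dilation"),
    ("%9", "prim::GetAttr", "padding"),
    ("%10", "prim::GetAttr", "padding"),
    ("%12", "prim::GetAttr", "stride"),
    ("%13", "prim::GetAttr", "stride"),
]


def toy_convert(nodes: List[Node]) -> Dict[str, int]:
    """Toy 'frontend' that only understands constants, like a converter
    with no access to the module instance behind `self`."""
    values: Dict[str, int] = {}
    for name, op, arg in nodes:
        if op != "prim::Constant":
            raise NotImplementedError(f"cannot resolve {op} for {name}")
        values[name] = int(arg)
    return values


def fold_attrs(nodes: List[Node], attrs: Dict[str, int]) -> List[Node]:
    """Option a) sketch: replace every prim::GetAttr with a
    prim::Constant holding the attribute's value."""
    folded: List[Node] = []
    for name, op, arg in nodes:
        if op == "prim::GetAttr":
            folded.append((name, "prim::Constant", attrs[str(arg)]))
        else:
            folded.append((name, op, arg))
    return folded
```

With `attrs = {"kernel_size": 3, "dilation": 1, "padding": 1, "stride": 1}`, `toy_convert(SCRIPT_NODES)` raises `NotImplementedError`, while `toy_convert(fold_attrs(SCRIPT_NODES, attrs))` resolves all eight values. Option b) avoids needing such a pass: writing the literal values directly in `forward`, as the test in this PR does, should let TorchScript record them as constants in the first place.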
   
   We chose b). What's your opinion?
   



