delldu commented on a change in pull request #8443:
URL: https://github.com/apache/tvm/pull/8443#discussion_r668549131
##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -3912,6 +3912,109 @@ def forward(self, x):
verify_model(Flip(axis=-1), input_data=input)
[email protected]_gpu
+def test_forward_im2col():
+ torch.set_grad_enabled(False)
+
+ class Im2col3x3(Module):
+ def __init__(self):
+ super(Im2col3x3, self).__init__()
+
+ def forward(self, x):
+ #
***********************************************************************************
+ #
+ # !!! DO NOT USE !!!
+ # F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1)
+ # for it broken TVM "if conditional expression" in torch script
mode
+ #
+ #
***********************************************************************************
+
+ return torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1))
+
+ class Im2col5x5(Module):
+ def __init__(self):
+ super(Im2col5x5, self).__init__()
+
+ def forward(self, x):
+ #
***********************************************************************************
+ #
+ # !!! DO NOT USE !!!
+ # F.unfold(x, kernel_size=5, dilation=1, padding=1, stride=2)
+ # for it broken TVM "if conditional expression" in torch script
mode
+ #
+ #
***********************************************************************************
+
+ return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2))
+
+ model = Im2col3x3()
+ input = torch.randn(2, 3, 32, 32)
+ verify_model(model, input_data=input)
+
+ verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)],
_get_default_vm_targets())
Review comment:
@wyc-ruiker , Thanks.
Your suggestion is good for trace mode, but not for script mode.
Let me give you a brief explanation. We use Im2col(3, 1, 1, 1) as an example.
**1. Trace Mode -- Static Graph**
`print(traced_model.graph)
`
```
graph(%self : __torch__.test_forward.___torch_mangle_6.Im2col,
%x : Float(2, 3, 32, 32)):
%3 : int = prim::Constant[value=3]()
%4 : int = prim::Constant[value=3]()
%5 : int[] = prim::ListConstruct(%3, %4)
%6 : int = prim::Constant[value=1]()
%7 : int = prim::Constant[value=1]()
%8 : int[] = prim::ListConstruct(%6, %7)
%9 : int = prim::Constant[value=1]()
%10 : int = prim::Constant[value=1]()
%11 : int[] = prim::ListConstruct(%9, %10)
%12 : int = prim::Constant[value=1]()
%13 : int = prim::Constant[value=1]()
%14 : int[] = prim::ListConstruct(%12, %13)
%15 : Float(2, 27, 1024) = aten::im2col(%x, %5, %8, %11, %14)
return (%15)
```
`print(traced_model.code)`
```
def forward(self,
x: Tensor) -> Tensor:
_0 = torch.im2col(x, [3, 3], [1, 1], [1, 1], [1, 1])
return _0
```
**2. Script Mode -- Dynamic Graph**
`print(script_module.graph)`
```
graph(%self : __torch__.test_forward.Im2col,
%x.1 : Tensor):
%3 : int = prim::GetAttr[name="kernel_size"](%self)
%4 : int = prim::GetAttr[name="kernel_size"](%self)
%6 : int = prim::GetAttr[name="dilation"](%self)
%7 : int = prim::GetAttr[name="dilation"](%self)
%9 : int = prim::GetAttr[name="padding"](%self)
%10 : int = prim::GetAttr[name="padding"](%self)
%12 : int = prim::GetAttr[name="stride"](%self)
%13 : int = prim::GetAttr[name="stride"](%self)
%15 : int[] = prim::ListConstruct(%3, %4)
%16 : int[] = prim::ListConstruct(%6, %7)
%17 : int[] = prim::ListConstruct(%9, %10)
%18 : int[] = prim::ListConstruct(%12, %13)
%19 : Tensor = aten::im2col(%x.1, %15, %16, %17, %18)
return (%19)
```
`print(script_model.code)`
```
def forward(self,
x: Tensor) -> Tensor:
_0 = self.kernel_size
_1 = self.kernel_size
_2 = self.dilation
_3 = self.dilation
_4 = self.padding
_5 = self.padding
_6 = self.stride
_7 = self.stride
_8 = torch.im2col(x, [_0, _1], [_2, _3], [_4, _5], [_6, _7])
return _8
```
**3. Root Cause**
The TVM framework cannot get any information about **self**, and therefore
fails.
**4. Options**
a) Redesign frontend.from_pytorch to pre-process the model; it is good for
the future.
b) Keep it. "transform parameter via forward" would be a better solution
for now.
We chose b). What's your opinion?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]