delldu commented on a change in pull request #8443:
URL: https://github.com/apache/tvm/pull/8443#discussion_r669321288
##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -3912,6 +3912,109 @@ def forward(self, x):
verify_model(Flip(axis=-1), input_data=input)
+@tvm.testing.uses_gpu
+def test_forward_im2col():
+    torch.set_grad_enabled(False)
+
+    class Im2col3x3(Module):
+        def __init__(self):
+            super(Im2col3x3, self).__init__()
+
+        def forward(self, x):
+            # ***********************************************************************************
+            #
+            # !!! DO NOT USE !!!
+            # F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1)
+            # because it breaks TVM's handling of the "if" conditional expression in TorchScript mode
+            #
+            # ***********************************************************************************
+
+            return torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1))
+
+    class Im2col5x5(Module):
+        def __init__(self):
+            super(Im2col5x5, self).__init__()
+
+        def forward(self, x):
+            # ***********************************************************************************
+            #
+            # !!! DO NOT USE !!!
+            # F.unfold(x, kernel_size=5, dilation=1, padding=1, stride=2)
+            # because it breaks TVM's handling of the "if" conditional expression in TorchScript mode
+            #
+            # ***********************************************************************************
+
+            return torch._C._nn.im2col(x, (5, 5), (1, 1), (1, 1), (2, 2))
+
+    model = Im2col3x3()
+    input = torch.randn(2, 3, 32, 32)
+    verify_model(model, input_data=input)
+
+    verify_script_model(Im2col5x5().eval(), [(2, 3, 32, 32)], _get_default_vm_targets())
Review comment:
Our model has control flow, so it has to go through script mode, and in practice we ran into the missing [aten::im2col] operator produced by F.unfold.
That is why we implemented im2col and wrote a unit test that runs in script mode. Simplified excerpts from our model:
```
def forward(self, x):
    x = self.simple_forward(x)
    B = x.size(0)
    C = x.size(1)
    H = x.size(2)
    W = x.size(3)
    # x = F.unfold(x, 3, dilation=1, padding=1, stride=1)
    x = torch._C._nn.im2col(x, (3, 3), (1, 1), (1, 1), (1, 1))
    return x.view(B, C * 9, H, W)
```
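For reference, here is a naive NumPy sketch of what aten::im2col computes, assuming the argument order (kernel_size, dilation, padding, stride) used in the calls above. It is not PyTorch's or TVM's implementation, and the helper name `im2col_reference` is hypothetical. It shows why the `view(B, C * 9, H, W)` above is valid: with a 3x3 kernel, padding 1 and stride 1, the number of sliding positions L equals H * W.

```python
import numpy as np


def im2col_reference(x, kernel_size, dilation, padding, stride):
    """Naive sketch: (N, C, H, W) -> (N, C * kh * kw, L) sliding-window columns.

    Argument order mirrors torch._C._nn.im2col(x, kernel_size, dilation, padding, stride).
    """
    n, c, h, w = x.shape
    kh, kw = kernel_size
    dh, dw = dilation
    ph, pw = padding
    sh, sw = stride
    # Number of window positions per spatial axis (standard convolution formula).
    out_h = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    out_w = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    # Zero-pad only the spatial axes.
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    cols = np.empty((n, c, kh, kw, out_h, out_w), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            hi, wj = i * dh, j * dw
            # A strided slice gathers kernel tap (i, j) for every output position.
            cols[:, :, i, j] = xp[:, :,
                                  hi:hi + sh * (out_h - 1) + 1:sh,
                                  wj:wj + sw * (out_w - 1) + 1:sw]
    # Rows are ordered channel-major, then kernel row, then kernel column.
    return cols.reshape(n, c * kh * kw, out_h * out_w)


x = np.arange(2 * 3 * 32 * 32, dtype=np.float32).reshape(2, 3, 32, 32)
cols3 = im2col_reference(x, (3, 3), (1, 1), (1, 1), (1, 1))
print(cols3.shape)                         # (2, 27, 1024)
print(cols3.reshape(2, 27, 32, 32).shape)  # (2, 27, 32, 32): L == H * W here
cols5 = im2col_reference(x, (5, 5), (1, 1), (1, 1), (2, 2))
print(cols5.shape)                         # (2, 75, 225): 15 x 15 positions
```

For the 5x5 / padding 1 / stride 2 case in the test, each spatial axis gives (32 + 2 - 4 - 1) // 2 + 1 = 15 positions, so the output is no longer reshapeable to (B, C * 25, H, W).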
```
bs: int = 65536
preds = []
start: int = 0
while start < n:
    stop: int = start + bs
    if stop > n:
        stop = n
    pred = ...
    preds += [pred.view(stop - start, 1, 3)]
    start = stop
y = torch.stack(preds, dim=0)
return y
```
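The loop above is a chunked-inference pattern: only one chunk of at most bs rows is processed at a time, so the memory held by intermediates scales with bs rather than with n. Below is a NumPy sketch of the same pattern under stated assumptions: `predict_chunk` is a hypothetical stand-in for the elided `pred = ...` step, and the (k, 1, 3) per-chunk shape is taken from the excerpt. One detail worth checking: `torch.stack` (like `np.stack`) requires every tensor to have the same size, and when n is not a multiple of bs the last chunk is shorter. This sketch therefore concatenates along the first axis (`torch.cat` is the PyTorch analogue), which yields an (n, 1, 3) result instead of a stacked chunk dimension.

```python
import numpy as np


def predict_chunk(chunk):
    # Hypothetical placeholder for the elided per-chunk computation:
    # maps a (k, 3) slice to a (k, 3) result.
    return chunk * 2.0


def chunked_predict(x, bs=65536):
    """Apply predict_chunk to x in slices of at most bs rows (a sketch)."""
    n = x.shape[0]
    preds = []
    start = 0
    while start < n:
        stop = min(start + bs, n)  # the last chunk may be shorter than bs
        pred = predict_chunk(x[start:stop])
        preds.append(pred.reshape(stop - start, 1, 3))
        start = stop
    # Chunks can differ in length, so join along axis 0 instead of stacking.
    return np.concatenate(preds, axis=0)


x = np.arange(30, dtype=np.float64).reshape(10, 3)
y = chunked_predict(x, bs=4)  # chunks of 4, 4 and 2 rows
print(y.shape)                # (10, 1, 3)
```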
The first excerpt needs aten::im2col; the second genuinely needs control flow.
We tried removing the while loop from the second excerpt, but memory usage then exceeded 64 GB!
--
This is an automated message from the Apache Git Service.