Lyken17 commented on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963906327
Hi @vin13,
This is a bit complex, so let me explain the situation. The bug was initially
found when I tried to compute the gradients of `nn.conv2d` with `groups`:
```python
program = """
def @main(%input0: Tensor[(1, 32, 224, 224), float32],
%v0_0_weight: Tensor[(32, 1, 3, 3), float32],
%v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32],
%v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1,
32, 224, 224), float32] {
%0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1,
1], groups=32, channels=32, kernel_size=[3, 3]);
%0
}
"""
mod = parse_module(program)  # helper wrapping Relay text parsing, e.g. tvm.parser.fromtext
mod = relay.transform.InferType()(mod)
bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_ir)
print(bwd_mod)  # the printed gradient IR:
"""
fn (%input0: Tensor[(1, 32, 224, 224), float32], %v0_0_weight: Tensor[(32,
1, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32],
%v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 32, 224,
224), float32], (Tensor[(1, 32, 224, 224), float32], Tensor[(32, 1, 3, 3),
float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
let %x = %input0;
let %x1 = zeros_like(%x);
let %x2 = %v0_0_weight;
let %x3 = zeros_like(%x2);
let %x4 = %v1_conv_0_0_weight;
let %x5 = zeros_like(%x4);
let %x6 = %v1_conv_1_weight;
let %x7 = zeros_like(%x6);
let %x8 = nn.conv2d(%x, %x2, padding=[1, 1, 1, 1], groups=32, channels=32,
kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 224, 224), float32] */;
let %x9 = zeros_like(%x8);
%0 = ones_like(%x8);
%1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);
%9 = (
let %x10 = add(%x1, %1);
%2 = tile(%0, reps=[1, 1, 1, 1]);
%3 = reshape(%x, newshape=[1, -1, 0, 0]);
%4 = reshape(%2, newshape=[-1, 1, 0, 0]);
%5 = nn.conv2d(%3, %4, padding=[1, 1, 1, 1], groups=32);
%6 = reshape(%5, newshape=[1, 1, 32, 3, 3]);
%7 = sum(%6, axis=[0]);
%8 = transpose(%7, axes=[1, 0, 2, 3]);
let %x11 = add(%x3, %8);
(%x10, %x11, %x5, %x7)
);
(%x8, %9)
}
"""
```
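For background on why `nn.conv2d_transpose` shows up as the input gradient in `%1` above: the gradient of a cross-correlation with respect to its input is a "full" correlation of the output gradient with the spatially flipped kernel, i.e. a transposed convolution. A minimal single-channel NumPy sketch (hypothetical helper names, stride 1, no padding), checked against finite differences:

```python
import numpy as np

def conv2d_valid(x, w):
    """Single-channel 2-D cross-correlation, stride 1, no padding."""
    kh, kw = w.shape
    H, W = x.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = (x[i:i + kh, j:j + kw] * w).sum()
    return out

def conv2d_transpose(g, w):
    """Gradient of conv2d_valid w.r.t. its input: full correlation of the
    output gradient with the spatially flipped kernel."""
    kh, kw = w.shape
    gp = np.pad(g, ((kh - 1, kh - 1), (kw - 1, kw - 1)))
    return conv2d_valid(gp, w[::-1, ::-1])

# Check the analytic gradient against central finite differences.
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
w = rng.standard_normal((3, 3))
g = rng.standard_normal((3, 3))   # upstream gradient, same shape as conv output

analytic = conv2d_transpose(g, w)

numeric = np.zeros_like(x)
eps = 1e-6
for p in range(5):
    for q in range(5):
        xp, xm = x.copy(), x.copy()
        xp[p, q] += eps
        xm[p, q] -= eps
        numeric[p, q] = ((conv2d_valid(xp, w) - conv2d_valid(xm, w)) * g).sum() / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)
```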
Note that the input gradient is computed as `%1 = nn.conv2d_transpose(%0, %x2,
padding=[1, 1, 1, 1], groups=32);`, i.e., the forward weight `%x2` is passed to
`nn.conv2d_transpose` unchanged. The shape of `%x2` is (32, 1, 3, 3), which is `OIHW` for
conv2d but `IOHW` for conv2d_transpose. This convention is consistent with PyTorch
[[1](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html),[2](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html#torch.nn.functional.conv_transpose2d)]
and
PaddlePaddle [[1](https://www.paddlepaddle.org.cn/documentation/docs/en/1.8/api/dygraph/Conv2DTranspose.html)].
If the original intent of `topi.nn.conv2d_transpose` is to use `OIHW` rather
than `IOHW`, then the primal gradient registration is wrong.
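To make the layout point concrete, here is a small sketch (hypothetical helper, not TVM API) of the weight shapes the two layouts imply. For the depthwise case in the IR above the two shapes coincide, which is what masks the discrepancy; for a general grouped case they differ:

```python
# Hypothetical helper (not TVM API) illustrating the two weight layouts.
def weight_shape(in_ch, out_ch, kh, kw, groups, layout):
    if layout == "OIHW":    # conv2d: (O, I/groups, H, W)
        return (out_ch, in_ch // groups, kh, kw)
    if layout == "IOHW":    # conv2d_transpose: (I, O/groups, H, W)
        return (in_ch, out_ch // groups, kh, kw)

# Depthwise case from the IR above: the shapes coincide, so the reused
# forward weight happens to type-check either way.
print(weight_shape(32, 32, 3, 3, 32, "OIHW"))  # (32, 1, 3, 3)
print(weight_shape(32, 32, 3, 3, 32, "IOHW"))  # (32, 1, 3, 3)

# A general grouped case: the shapes differ, so the layout choice matters.
print(weight_shape(32, 16, 1, 1, 4, "OIHW"))   # (16, 8, 1, 1)
print(weight_shape(32, 16, 1, 1, 4, "IOHW"))   # (32, 4, 1, 1)
```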
From my personal perspective, I prefer `IOHW`, since `conv2d_transpose` should
take a transposed weight, and this would keep TVM consistent with PyTorch. Is
there any specific reason (e.g., performance tuning) that makes `OIHW` preferable
to `IOHW`?
In either case, `groups` support for `conv2d_transpose` is missing and
should be added.
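For reference, a grouped convolution (and likewise a grouped transposed convolution) decomposes into independent convolutions over per-group channel slices, which is one way the missing `groups` support could be implemented. A minimal NumPy sketch with hypothetical helper names (stride 1, no padding):

```python
import numpy as np

def conv2d_oihw(x, w):
    """Dense 2-D cross-correlation: x is (Ci, H, W), w is (Co, Ci, kh, kw)."""
    Co, Ci, kh, kw = w.shape
    H, W = x.shape[1:]
    out = np.zeros((Co, H - kh + 1, W - kw + 1))
    for o in range(Co):
        for i in range(out.shape[1]):
            for j in range(out.shape[2]):
                out[o, i, j] = (x[:, i:i + kh, j:j + kw] * w[o]).sum()
    return out

def grouped_conv2d(x, w, groups):
    """Grouped conv as per-group dense convs; w is (Co, Ci // groups, kh, kw)."""
    xs = np.split(x, groups, axis=0)   # each (Ci/groups, H, W)
    ws = np.split(w, groups, axis=0)   # each (Co/groups, Ci/groups, kh, kw)
    return np.concatenate([conv2d_oihw(xg, wg) for xg, wg in zip(xs, ws)])

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6, 6))
w = rng.standard_normal((2, 4, 3, 3))
# groups=1 reduces to the dense convolution
assert np.allclose(grouped_conv2d(x, w, 1), conv2d_oihw(x, w))
# groups=2: weight shrinks to (Co, Ci/2, kh, kw); output is still (Co, 4, 4)
w2 = rng.standard_normal((2, 2, 3, 3))
assert grouped_conv2d(x, w2, 2).shape == (2, 4, 4)
```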