Lyken17 edited a comment on pull request #9465:
URL: https://github.com/apache/tvm/pull/9465#issuecomment-963906327
Hi @vin13,
The situation is a bit complex: the bug was initially found when I tried to
calculate the gradients of `nn.conv2d` with `groups`:
```python
import tvm
from tvm import relay

program = """
#[version = "0.0.5"]
def @main(%input0: Tensor[(1, 32, 224, 224), float32],
          %v0_0_weight: Tensor[(32, 1, 3, 3), float32],
          %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32],
          %v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> Tensor[(1, 32, 224, 224), float32] {
  %0 = nn.conv2d(%input0, %v0_0_weight, strides=[1, 1], padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
  %0
}
"""
mod = tvm.parser.fromtext(program)  # parse the Relay text format into an IRModule
mod = relay.transform.InferType()(mod)
bwd_ir = relay.transform.gradient(mod['main'], mode="first_order")
bwd_mod = tvm.IRModule.from_expr(bwd_ir)
print(bwd_ir)  # prints:
"""
fn (%input0: Tensor[(1, 32, 224, 224), float32], %v0_0_weight: Tensor[(32,
1, 3, 3), float32], %v1_conv_0_0_weight: Tensor[(32, 1, 3, 3), float32],
%v1_conv_1_weight: Tensor[(16, 32, 1, 1), float32]) -> (Tensor[(1, 32, 224,
224), float32], (Tensor[(1, 32, 224, 224), float32], Tensor[(32, 1, 3, 3),
float32], Tensor[(32, 1, 3, 3), float32], Tensor[(16, 32, 1, 1), float32])) {
let %x = %input0;
let %x1 = zeros_like(%x);
let %x2 = %v0_0_weight;
let %x3 = zeros_like(%x2);
let %x4 = %v1_conv_0_0_weight;
let %x5 = zeros_like(%x4);
let %x6 = %v1_conv_1_weight;
let %x7 = zeros_like(%x6);
let %x8 = nn.conv2d(%x, %x2, padding=[1, 1, 1, 1], groups=32, channels=32,
kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 224, 224), float32] */;
let %x9 = zeros_like(%x8);
%0 = ones_like(%x8);
%1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);
%9 = (
let %x10 = add(%x1, %1);
%2 = tile(%0, reps=[1, 1, 1, 1]);
%3 = reshape(%x, newshape=[1, -1, 0, 0]);
%4 = reshape(%2, newshape=[-1, 1, 0, 0]);
%5 = nn.conv2d(%3, %4, padding=[1, 1, 1, 1], groups=32);
%6 = reshape(%5, newshape=[1, 1, 32, 3, 3]);
%7 = sum(%6, axis=[0]);
%8 = transpose(%7, axes=[1, 0, 2, 3]);
let %x11 = add(%x3, %8);
(%x10, %x11, %x5, %x7)
);
(%x8, %9)
}
"""
```
It is shown that the forward `nn.conv2d` is differentiated into `%1 = nn.conv2d_transpose(%0, %x2, padding=[1, 1, 1, 1], groups=32);`, i.e. the forward weight `%x2` is passed to `nn.conv2d_transpose` unchanged. The shape of `%x2` is (32, 1, 3, 3), which is `OIHW` for conv2d and `IOHW` for conv2d_transpose. This is consistent with PyTorch [[1](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html), [2](https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose2d.html#torch.nn.functional.conv_transpose2d)] and PaddlePaddle [[1](https://www.paddlepaddle.org.cn/documentation/docs/en/1.8/api/dygraph/Conv2DTranspose.html)].
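For illustration, here is a minimal PyTorch sketch (my own sanity check, not code from this PR) showing that the same depthwise weight tensor is accepted by both `conv2d` and `conv_transpose2d`, and that the transposed convolution reproduces autograd's input gradient:
```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 32, 224, 224, requires_grad=True)
w = torch.randn(32, 1, 3, 3)  # (O, I/groups, kH, kW) = OIHW for conv2d

y = F.conv2d(x, w, padding=1, groups=32)
grad_out = torch.ones_like(y)

# The same (32, 1, 3, 3) tensor is interpreted as
# (in_channels, out_channels/groups, kH, kW) = IOHW here.
grad_in = F.conv_transpose2d(grad_out, w, padding=1, groups=32)

y.backward(grad_out)
assert torch.allclose(grad_in, x.grad, atol=1e-4)  # matches autograd's input gradient
```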
However, if the original purpose of `topi.nn.conv2d_transpose` is to use `OIHW` rather than `IOHW`, then the primal gradient registration is wrong. From my personal perspective, I prefer `IOHW`, since `conv2d_transpose` should take a transposed weight, and this would keep TVM consistent with PyTorch. Is there any specific reason (e.g., performance tuning) that makes `OIHW` better than `IOHW`?
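To make the layout relation concrete, here is a small NumPy sketch (the helper name is mine, purely illustrative). Converting a grouped weight between the two conventions swaps the per-group output and input axes; in the depthwise case above the two layouts happen to coincide, which is why the gradient can pass the forward weight through unchanged:
```python
import numpy as np

def oihw_to_iohw(w, groups):
    # OIHW: (O, I/groups, kH, kW) -> IOHW: (I, O/groups, kH, kW)
    o, i_per_g, kh, kw = w.shape
    w = w.reshape(groups, o // groups, i_per_g, kh, kw)
    w = w.transpose(0, 2, 1, 3, 4)  # swap O/groups and I/groups within each group
    return w.reshape(i_per_g * groups, o // groups, kh, kw)

w = np.random.randn(32, 1, 3, 3)  # the depthwise weight from the example
# groups == O == I, so both swapped axes have size 1: the layouts coincide
assert np.array_equal(oihw_to_iohw(w, groups=32), w)
```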
There are two ways to fix the issue:
* The first is to refactor `topi.nn.conv2d_transpose` to take an `IOHW` weight. The gradient calculation would then not need to change, and users would benefit from a layout and API consistent with other frameworks like PyTorch.
* The second is to re-write `nn.conv2d`'s primal gradient to make it compatible with an `OIHW` weight layout (see the sketch below). The advantage is that only one place in the code needs to change, but users may get confused in the future because of the inconsistent layout.
In either case, `groups` support for `conv2d_transpose` is missing and should be added.
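As a rough sketch of the second option (assuming `conv2d_transpose` keeps an `OIHW` weight and, hypothetically, gains `groups` support), the gradient registration would flip the weight's per-group output and input axes before calling the transposed convolution. The helper below is illustrative, not TVM's actual gradient code:
```python
from tvm import relay

def conv2d_input_grad(grad_out, weight, weight_shape, strides, padding, groups):
    # weight comes from the forward conv2d in OIHW, i.e. (O, I/groups, kH, kW)
    o, i_per_g, kh, kw = weight_shape  # static shape, assumed known
    w = relay.reshape(weight, [groups, o // groups, i_per_g, kh, kw])
    w = relay.transpose(w, axes=[0, 2, 1, 3, 4])  # swap per-group O and I
    w = relay.reshape(w, [i_per_g * groups, o // groups, kh, kw])
    return relay.nn.conv2d_transpose(
        grad_out, w, strides=strides, padding=padding, groups=groups)
```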