JCBrouwer opened a new issue #10223:
URL: https://github.com/apache/tvm/issues/10223
I'm trying to convert a PyTorch model which makes use of
`torch.nn.functional.conv_transpose2d` and am running into issues with my
converter to the corresponding `tvm.relay.op.nn.conv2d_transpose` operation.
I've done a little monkey patching on the PyTorchOpConverter as the
operations that `torch.nn.functional.conv2d/conv_transpose2d` trace to
(`aten::conv2d` and `aten::conv_transpose2d`) aren't covered by default. I've
added functions to convert each one to the PyTorchOpConverter so that I have
access to `self.infer_shape(weight)` in the functions as follows:
<details>
<summary>Converter implementation</summary>
```py
class MyPyTorchOpConverter(PyTorchOpConverter):
def __init__(self, prelude, default_dtype):
super().__init__(prelude, default_dtype)
self.update_convert_map(
{"aten::conv2d": self.convert_conv2d, "aten::conv_transpose2d":
self.convert_conv_transpose2d}
)
def convert_conv2d(self, inputs, input_types):
data = inputs[0]
weight = inputs[1]
bias = inputs[2]
strides = inputs[3]
padding = inputs[4]
dilation = inputs[5]
groups = inputs[6]
channels, input_channels, kh, kw = self.infer_shape(weight) # OIHW
if groups > 1 and input_channels == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, kh, kw)
weight = relay.op.transform.reshape(weight, new_weight_shape)
res = relay.op.nn.conv2d(
data, weight, strides=strides, padding=padding,
dilation=dilation, groups=groups, channels=channels
)
if bias is not None:
res = relay.op.nn.bias_add(res, bias)
return res
def convert_conv_transpose2d(self, inputs, input_types):
data = inputs[0]
weight = inputs[1]
bias = inputs[2]
strides = inputs[3]
padding = inputs[4]
output_padding = inputs[5]
groups = inputs[6]
dilation = inputs[7]
input_channels, channels, kh, kw = list(self.infer_shape(weight)) #
IOHW
if groups > 1 and channels == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, kh, kw)
weight = relay.op.transform.reshape(weight, new_weight_shape)
res = relay.op.nn.conv2d_transpose(
data,
weight,
strides=strides,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
)
if bias is not None:
res = relay.op.nn.bias_add(res, bias)
return res
tvm.relay.frontend.pytorch.PyTorchOpConverter = MyPyTorchOpConverter
```
</details>
The implementations of the convertors are adapted from
`tvm.relay.frontend.pytorch.PyTorchOpConverter.convolution(inputs,
input_types)` but updated to support the call signature of
`torch.nn.functional.conv2d/conv_transpose2d`.
The problem I'm seeing is that it seems like
`tvm.relay.op.nn.conv2d_transpose()` doesn't respect the `groups` argument.
When I print the input and outputs of the first 4 conv(_transpose) ops in my
network, the PyTorch shapes are the following:
<details>
<summary>PyTorch shapes</summary>
```
conv2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 4, 4)
conv2d
input (1, 1536, 4, 4)
weight (9, 512, 1, 1)
groups 3
out (1, 9, 4, 4)
conv_transpose2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 9, 9)
conv2d
data (1, 1536, 11, 11)
weight (1536, 1, 4, 4)
groups 1536
out (1, 1536, 8, 8)
```
</details>
While the TVM shapes are:
<details>
<summary>TVM shapes</summary>
```
conv2d
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 1536, 4, 4)
conv2d
input [1, 1536, 4, 4]
weight [9, 512, 1, 1]
groups 3
out (1, 9, 4, 4)
conv2d_transpose
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 512, 9, 9)
conv2d
input [1, 512, 11, 11]
weight [512, 1, 4, 4]
groups 1536
TVMError
```
</details>
Notice that the output shape of `tvm.relay.op.nn.conv2d_transpose()` does
not have the correct number of channels (output is as if `groups` = 1). This
leads to the error in the next conv2d operation:
<details>
<summary>Error traceback</summary>
```
Traceback (most recent call last):
File "/home/hans/code/stylegan3/func.py", line 217, in <module>
Gtvm, tvm_params = relay.frontend.pytorch.from_pytorch(
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line
4010, in from_pytorch
outputs =
converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs,
ret_name)
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line
3385, in convert_operators
relay_out = relay_op(
File "/home/hans/code/stylegan3/func.py", line 101, in convert_conv2d
print("out", self.infer_shape(res))
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 204,
in infer_shape
typ = self.infer_type(inputs, mod=mod)
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 162,
in infer_type
new_mod = transform.InferType()(new_mod)
File "/home/hans/code/tvm/python/tvm/ir/transform.py", line 161, in
__call__
return _ffi_transform_api.RunPass(self, mod)
File "/home/hans/code/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line
237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
7: TVMFuncCall
6:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::transform::Pass,
tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>,
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*,
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
5: tvm::transform::Pass::operator()(tvm::IRModule) const
4: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
3: tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
2:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::IRModule,
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
tvm::transform::PassContext
const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*,
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
0: tvm::relay::TypeSolver::Solve() [clone .cold]
9: TVMFuncCall
8:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::transform::Pass,
tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass,
tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>,
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*,
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule,
tvm::transform::PassContext const&) const
4:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
(tvm::IRModule,
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
tvm::transform::PassContext
const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&,
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*,
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
2: tvm::relay::TypeSolver::Solve()
1:
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool
(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&,
tvm::TypeReporter const&)>::AssignTypedLambda<bool
(*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&,
tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void>
const&, int, tvm::Attrs const&, tvm::TypeReporter
const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs,
tvm::runtime::TVMRetValue*)
0: tvm::relay::ReshapeRel(tvm::runtime::Array<tvm::Type, void> const&,
int, tvm::Attrs const&, tvm::TypeReporter const&)
File "/home/hans/code/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (false) is false: [14:35:14]
/home/hans/code/tvm/src/relay/op/tensor/transform.cc:787:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: oshape_sum == data_shape_sum (24576 vs. 8192) : Input tensor
shape(1536,1,4,4) and reshaped shape(512,1,4,4) are not compatible!
```
</details>
As a workaround, I've rewritten my conv_transpose2d converter to manually
split the data and weights into groups, perform each transposed conv, and then
concatenate them back. This converter does seem to give the correct output
shape although I haven't yet tested the outputs for correctness, I might have
just gotten lucky with the shapes.
<details>
<summary>Workaround converter implementation (manual grouping)</summary>
```py
def convert_conv_transpose2d_workaround(self, inputs, input_types):
data = inputs[0]
weight = inputs[1]
bias = inputs[2]
strides = inputs[3]
padding = inputs[4]
output_padding = inputs[5]
groups = inputs[6]
dilation = inputs[7]
input_channels, channels, kh, kw = list(self.infer_shape(weight))
if groups > 1 and channels == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, kh, kw)
weight = relay.op.transform.reshape(weight, new_weight_shape)
datas = relay.op.split(data, groups, axis=1)
weights = relay.op.split(weight, groups, axis=0)
rs = []
for d, w in zip(datas, weights):
r = relay.op.nn.conv2d_transpose(
d, w, strides=strides, padding=padding,
output_padding=output_padding, dilation=dilation, groups=1
)
if bias is not None:
r = relay.op.nn.bias_add(r, bias)
rs.append(r)
res = relay.op.concatenate(rs, axis=1)
return res
```
</details>
### Expected behavior
The groups argument of `tvm.relay.op.nn.conv2d_transpose` should work
correctly like `tvm.relay.op.nn.conv2d` does.
### Actual behavior
The transposed convolution seems to only be applied to a single group?
### Environment
Ubuntu 20.04
PyTorch 1.12.0.dev20220210
TVM 0.9.dev525+g8aeb72265 (compiled from main a couple hours ago)
CUDA 11.4
### Steps to reproduce
```py
from copy import deepcopy
import torch
import tvm.relay
from torch.nn.functional import conv_transpose2d
_original_get_constant = deepcopy(tvm.relay.frontend.pytorch._get_constant)
def _my_get_constant(node):
"""Monkey patch in support for prim::Constant lists, I guess
torch.jit.optimize_for_inference introduces these?"""
if node.output().type().kind() == "ListType":
print("WARNING: Encountered ListType in _get_constant, doing weird
eval stuff to get the list value:", end=" ")
lst = eval(node.__repr__().split("value=")[1].replace("]()", ""))
print(lst)
return lst
else:
return _original_get_constant(node)
tvm.relay.frontend.pytorch._get_constant = _my_get_constant
def convert_conv_transpose2d(inputs, input_types):
data = inputs[0]
weight = inputs[1]
bias = inputs[2]
strides = inputs[3]
padding = inputs[4]
output_padding = inputs[5]
groups = inputs[6]
dilation = inputs[7]
res = tvm.relay.op.nn.conv2d_transpose(
data,
weight,
strides=strides,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
)
if bias is not None:
res = tvm.relay.op.nn.bias_add(res, bias)
return res
class ModulatedConvTranspose2D(torch.nn.Module):
def forward(self, x, w, s):
B, C, H, W = x.shape
I, O, KH, KW = w.shape
# weight is different for each input in batch (this is why we want
grouped conv transpose)
w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
w = w.reshape(B * I, O, KH, KW)
x = x.reshape(1, B * C, H, W)
x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1),
output_padding=(1, 1), groups=B)
# Check failed: oshape_sum == data_shape_sum (524288 vs. 131072) :
Input tensor shape(4,256,16,32) and reshaped shape(1,256,16,32) are not
compatible!
x = x.reshape(B, O, H * 2, W * 2)
return x
with torch.inference_mode():
b, c, h, w, k = 4, 512, 8, 16, 3
inputs = torch.rand(b, c, h, w)
weights = torch.rand(c, c // 2, k, k)
styles = torch.rand(b)
torch_mod = torch.jit.optimize_for_inference(
torch.jit.trace(ModulatedConvTranspose2D().eval(), (inputs, weights,
styles))
)
outputs_torch = torch_mod(inputs, weights, styles)
print("Torch output shape", outputs_torch.shape) # torch.Size([4, 256,
16, 32])
tvm_mod, params = tvm.relay.frontend.pytorch.from_pytorch(
torch_mod,
[("inputs", inputs.shape), ("weights", weights.shape), ("styles",
styles.shape)],
{"aten::conv_transpose2d": convert_conv_transpose2d},
)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]