JCBrouwer opened a new issue #10223:
URL: https://github.com/apache/tvm/issues/10223


   I'm trying to convert a PyTorch model which makes use of 
`torch.nn.functional.conv_transpose2d` and am running into issues with my 
converter to the corresponding `tvm.relay.op.nn.conv2d_transpose` operation.
   
   I've done a little monkey patching on the PyTorchOpConverter as the 
operations that `torch.nn.functional.conv2d/conv_transpose2d` trace to 
(`aten::conv2d` and `aten::conv_transpose2d`) aren't covered by default. I've 
added functions to convert each one to the PyTorchOpConverter so that I have 
access to `self.infer_shape(weight)` in the functions as follows:
   
   <details>
     <summary>Converter implementation</summary>
   
   ```py
   class MyPyTorchOpConverter(PyTorchOpConverter):
       """PyTorchOpConverter extended with ``aten::conv2d`` and ``aten::conv_transpose2d``.

       The converters are methods on a subclass (rather than entries in a custom
       map passed to ``from_pytorch``) so that they can call ``self.infer_shape``.
       """

       def __init__(self, prelude, default_dtype):
           super().__init__(prelude, default_dtype)
           self.update_convert_map(
               {
                   "aten::conv2d": self.convert_conv2d,
                   "aten::conv_transpose2d": self.convert_conv_transpose2d,
               }
           )

       def convert_conv2d(self, inputs, input_types):
           """Convert ``aten::conv2d`` to ``relay.op.nn.conv2d``.

           ``inputs`` is read positionally in ``torch.nn.functional.conv2d`` order:
           (input, weight, bias, stride, padding, dilation, groups).
           """
           data = inputs[0]
           weight = inputs[1]
           bias = inputs[2]
           strides = inputs[3]
           padding = inputs[4]
           dilation = inputs[5]
           groups = inputs[6]

           channels, input_channels, kh, kw = self.infer_shape(weight)  # OIHW

           # Depthwise case (one input channel per group): reshape the weight the
           # same way as the upstream ``convolution`` converter this was adapted from.
           if groups > 1 and input_channels == 1:
               channel_multiplier = channels // groups
               new_weight_shape = (groups, channel_multiplier, kh, kw)
               weight = relay.op.transform.reshape(weight, new_weight_shape)

           res = relay.op.nn.conv2d(
               data, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, channels=channels
           )
           if bias is not None:
               res = relay.op.nn.bias_add(res, bias)

           return res

       def convert_conv_transpose2d(self, inputs, input_types):
           """Convert ``aten::conv_transpose2d`` to ``relay.op.nn.conv2d_transpose``.

           ``inputs`` is read positionally in ``torch.nn.functional.conv_transpose2d``
           order: (input, weight, bias, stride, padding, output_padding, groups, dilation).

           PyTorch stores transposed-convolution weights as
           (in_channels, out_channels // groups, kH, kW), so the layer produces
           ``weight.shape[1] * groups`` output channels. That count is passed to
           Relay explicitly: without it the output was observed to have only
           ``weight.shape[1]`` channels, i.e. ``groups`` was ignored.
           """
           data = inputs[0]
           weight = inputs[1]
           bias = inputs[2]
           strides = inputs[3]
           padding = inputs[4]
           output_padding = inputs[5]
           groups = inputs[6]
           dilation = inputs[7]

           _, out_channels_per_group, kh, kw = self.infer_shape(weight)  # IOHW
           channels = out_channels_per_group * groups

           # Unlike convert_conv2d, the weight is passed through unchanged. The old
           # "depthwise" branch here only ran when weight.shape[1] == 1 and then
           # computed ``1 // groups == 0``, i.e. reshaped to a zero-sized weight.
           # NOTE(review): confirm that the Relay build in use honours explicit
           # ``channels``/``kernel_size`` for grouped conv2d_transpose.
           res = relay.op.nn.conv2d_transpose(
               data,
               weight,
               strides=strides,
               padding=padding,
               output_padding=output_padding,
               dilation=dilation,
               groups=groups,
               channels=channels,
               kernel_size=(kh, kw),
           )
           if bias is not None:
               res = relay.op.nn.bias_add(res, bias)

           return res

   # Replace the frontend's converter class so from_pytorch picks up the subclass.
   tvm.relay.frontend.pytorch.PyTorchOpConverter = MyPyTorchOpConverter
   ```
   
   </details>
   
   The implementations of the converters are adapted from 
`tvm.relay.frontend.pytorch.PyTorchOpConverter.convolution(inputs, input_types)` 
but updated to support the call signatures of `torch.nn.functional.conv2d` and 
`torch.nn.functional.conv_transpose2d`.
   
   The problem I'm seeing is that `tvm.relay.op.nn.conv2d_transpose()` does not 
seem to respect the `groups` argument. When I print the inputs and outputs of 
the first four conv(_transpose) ops in my network, the PyTorch shapes are as 
follows:
   
   <details>
     <summary>PyTorch shapes</summary>
   
   ```
   conv2d
   input (1, 1536, 4, 4)
   weight (1536, 512, 3, 3)
   groups 3
   out (1, 1536, 4, 4)
   
   conv2d
   input (1, 1536, 4, 4)
   weight (9, 512, 1, 1)
   groups 3
   out (1, 9, 4, 4)
   
   conv_transpose2d
   input (1, 1536, 4, 4)
   weight (1536, 512, 3, 3)
   groups 3
   out (1, 1536, 9, 9)
   
   conv2d
   data (1, 1536, 11, 11)
   weight (1536, 1, 4, 4)
   groups 1536
   out (1, 1536, 8, 8)
   ```
   
   </details>
   
   While the TVM shapes are:
   <details>
     <summary>TVM shapes</summary>
   
   ```
   conv2d
   input [1, 1536, 4, 4]
   weight [1536, 512, 3, 3]
   groups 3
   out (1, 1536, 4, 4)
   
   conv2d
   input [1, 1536, 4, 4]
   weight [9, 512, 1, 1]
   groups 3
   out (1, 9, 4, 4)
   
   conv2d_transpose
   input [1, 1536, 4, 4]
   weight [1536, 512, 3, 3]
   groups 3
   out (1, 512, 9, 9)
   
   conv2d
   input [1, 512, 11, 11]
   weight [512, 1, 4, 4]
   groups 1536
   TVMError
   ```
   
   </details>
   
   Notice that the output shape of `tvm.relay.op.nn.conv2d_transpose()` does 
not have the correct number of channels (output is as if `groups` = 1). This 
leads to the error in the next conv2d operation:
   
   <details>
     <summary>Error traceback</summary>
   
   ```
   Traceback (most recent call last):
     File "/home/hans/code/stylegan3/func.py", line 217, in <module>
       Gtvm, tvm_params = relay.frontend.pytorch.from_pytorch(
     File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 
4010, in from_pytorch
       outputs = 
converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, 
ret_name)
     File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 
3385, in convert_operators
       relay_out = relay_op(
     File "/home/hans/code/stylegan3/func.py", line 101, in convert_conv2d
       print("out", self.infer_shape(res))
     File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 204, 
in infer_shape
       typ = self.infer_type(inputs, mod=mod)
     File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 162, 
in infer_type
       new_mod = transform.InferType()(new_mod)
     File "/home/hans/code/tvm/python/tvm/ir/transform.py", line 161, in 
__call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/home/hans/code/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 
237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     7: TVMFuncCall
     6: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
 (tvm::transform::Pass, 
tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass,
 tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, 
tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, 
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, 
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, 
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     5: tvm::transform::Pass::operator()(tvm::IRModule) const
     4: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     2: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
 (tvm::IRModule, 
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
 tvm::transform::PassContext 
const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, 
tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, 
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, 
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
     0: tvm::relay::TypeSolver::Solve() [clone .cold]
     9: TVMFuncCall
     8: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
 (tvm::transform::Pass, 
tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass,
 tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, 
tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, 
std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, 
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, 
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     7: tvm::transform::Pass::operator()(tvm::IRModule) const
     6: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
     4: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule
 (tvm::IRModule, 
tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule,
 tvm::transform::PassContext 
const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, 
tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, 
tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, 
tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
     2: tvm::relay::TypeSolver::Solve()
     1: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool
 (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, 
tvm::TypeReporter const&)>::AssignTypedLambda<bool 
(*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, 
tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> 
const&, int, tvm::Attrs const&, tvm::TypeReporter 
const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> 
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
     0: tvm::relay::ReshapeRel(tvm::runtime::Array<tvm::Type, void> const&, 
int, tvm::Attrs const&, tvm::TypeReporter const&)
     File "/home/hans/code/tvm/src/relay/analysis/type_solver.cc", line 624
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (false) is false: [14:35:14] 
/home/hans/code/tvm/src/relay/op/tensor/transform.cc:787: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
   
     Check failed: oshape_sum == data_shape_sum (24576 vs. 8192) : Input tensor 
shape(1536,1,4,4) and reshaped shape(512,1,4,4) are not compatible!
   ```
   
   </details>
   
   As a workaround, I've rewritten my conv_transpose2d converter to manually 
split the data and weights into groups, perform each transposed conv separately, 
and then concatenate the results. This converter does seem to give the correct 
output shape, although I haven't yet tested the outputs for correctness; I might 
have just gotten lucky with the shapes.
   
   
   <details>
     <summary>Workaround converter implementation (manual grouping)</summary>
   
   ```py
       def convert_conv_transpose2d_workaround(self, inputs, input_types):
           """Convert ``aten::conv_transpose2d`` by lowering each group separately.

           The input is split along the channel axis and the weight along axis 0
           into ``groups`` pieces. An ungrouped ``conv2d_transpose`` is run on each
           (data, weight) pair, and the results are concatenated along the channel
           axis. ``inputs`` is read positionally as (input, weight, bias, stride,
           padding, output_padding, groups, dilation).
           """
           data = inputs[0]
           weight = inputs[1]
           bias = inputs[2]
           strides = inputs[3]
           padding = inputs[4]
           output_padding = inputs[5]
           groups = inputs[6]
           dilation = inputs[7]

           # PyTorch's transposed weight is (in_channels, out_channels // groups, kH, kW),
           # so splitting axis 0 pairs each weight slice with its group of input channels.
           # (The old depthwise reshape branch computed ``1 // groups == 0`` and is dropped.)
           datas = relay.op.split(data, groups, axis=1)
           weights = relay.op.split(weight, groups, axis=0)

           group_outputs = [
               relay.op.nn.conv2d_transpose(
                   d, w, strides=strides, padding=padding, output_padding=output_padding, dilation=dilation, groups=1
               )
               for d, w in zip(datas, weights)
           ]
           res = relay.op.concatenate(group_outputs, axis=1)

           # The bias has one entry per *total* output channel (out_channels_per_group
           # * groups). It therefore has to be added after concatenation; adding it to
           # each group's out_channels_per_group-channel result mismatches for groups > 1.
           if bias is not None:
               res = relay.op.nn.bias_add(res, bias)

           return res
   ```
   
   </details>
   
   ### Expected behavior
   
   The groups argument of `tvm.relay.op.nn.conv2d_transpose` should work 
correctly like `tvm.relay.op.nn.conv2d` does.
   
   ### Actual behavior
   
   The transposed convolution seems to only be applied to a single group?
   
   ### Environment
   
   Ubuntu 20.04
   PyTorch 1.12.0.dev20220210
   TVM 0.9.dev525+g8aeb72265 (compiled from main a couple hours ago)
   CUDA 11.4
   
   ### Steps to reproduce
   
   ```py
   from copy import deepcopy
   
   import torch
   import tvm.relay
   from torch.nn.functional import conv_transpose2d
   
   # copy.deepcopy returns a function object unchanged, so this simply keeps a
   # reference to the stock implementation before it is replaced below.
   _original_get_constant = deepcopy(tvm.relay.frontend.pytorch._get_constant)


   def _my_get_constant(node):
       """Return a ``prim::Constant`` node's value, adding support for list constants.

       List-typed constants (presumably introduced by torch.jit.optimize_for_inference
       -- unverified) are parsed out of the node's textual repr. Every other node is
       delegated to the original TVM implementation.
       """
       import ast

       if node.output().type().kind() == "ListType":
           print("WARNING: Encountered ListType in _get_constant, doing weird eval stuff to get the list value:", end=" ")
           # Keep the text after "value=" and strip the trailing "]()" attribute
           # syntax, leaving what is presumably a Python list literal.
           literal = node.__repr__().split("value=")[1].replace("]()", "")
           # ast.literal_eval accepts only Python literals, so unlike eval() it
           # cannot execute arbitrary code from an unexpected repr.
           lst = ast.literal_eval(literal)
           print(lst)
           return lst
       return _original_get_constant(node)


   tvm.relay.frontend.pytorch._get_constant = _my_get_constant
   
   
   def convert_conv_transpose2d(inputs, input_types):
       """Map ``aten::conv_transpose2d`` onto ``tvm.relay.op.nn.conv2d_transpose``.

       ``inputs`` is read positionally as (data, weight, bias, strides, padding,
       output_padding, groups, dilation). ``channels`` is intentionally not passed,
       so Relay derives the output channel count from the weight. This is the
       configuration under which the grouped output shape was reported wrong.
       """
       data, weight, bias, strides, padding, output_padding, groups, dilation = [
           inputs[position] for position in range(8)
       ]

       conv_attrs = {
           "strides": strides,
           "padding": padding,
           "output_padding": output_padding,
           "dilation": dilation,
           "groups": groups,
       }
       out = tvm.relay.op.nn.conv2d_transpose(data, weight, **conv_attrs)

       return out if bias is None else tvm.relay.op.nn.bias_add(out, bias)
   
   
   class ModulatedConvTranspose2D(torch.nn.Module):
       """Per-sample weight-modulated transposed convolution via a grouped conv."""

       def forward(self, x, w, s):
           """Apply ``conv_transpose2d`` with weight ``w * s[b]`` to each batch element ``b``.

           x: (B, C, H, W) input; w: (I, O, KH, KW) transposed-conv weight with I == C;
           s: B per-sample scales. Returns (B, O, H_out, W_out).
           """
           B, C, H, W = x.shape
           I, O, KH, KW = w.shape

           # The weight is different for each input in the batch, which is why a
           # grouped conv transpose is wanted: fold the batch into the channel axis
           # and use one group per batch element.
           w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
           w = w.reshape(B * I, O, KH, KW)

           x = x.reshape(1, B * C, H, W)

           x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B)

           # In TVM this reshape fails with:
           # Check failed: oshape_sum == data_shape_sum (524288 vs. 131072) : Input tensor
           # shape(4,256,16,32) and reshaped shape(1,256,16,32) are not compatible!
           # Use the actual output size: it equals (2H, 2W) only for a 3x3 kernel.
           x = x.reshape(B, O, x.shape[-2], x.shape[-1])

           return x
   
   
   # Reproduction: trace the module, run it in PyTorch, then convert it with TVM.
   with torch.inference_mode():
       b, c, h, w, k = 4, 512, 8, 16, 3
       inputs = torch.rand(b, c, h, w)
       weights = torch.rand(c, c // 2, k, k)
       styles = torch.rand(b)

       torch_mod = torch.jit.optimize_for_inference(
           torch.jit.trace(ModulatedConvTranspose2D().eval(), (inputs, weights, styles))
       )

       outputs_torch = torch_mod(inputs, weights, styles)
       print("Torch output shape", outputs_torch.shape)  # torch.Size([4, 256, 16, 32])

       # The third argument is a custom converter map overriding aten::conv_transpose2d.
       tvm_mod, params = tvm.relay.frontend.pytorch.from_pytorch(
           torch_mod,
           [("inputs", inputs.shape), ("weights", weights.shape), ("styles", styles.shape)],
           {"aten::conv_transpose2d": convert_conv_transpose2d},
       )
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to