Lyken17 opened a new pull request #9465:
URL: https://github.com/apache/tvm/pull/9465
The default shape format in TVM is `N x C x iH x iW` for the input and `O x I x kH x kW` for the weight, so the proper shapes for Conv2dTransposed should be:
* input: (batch, in_channels, iH, iW)
* weight: (out_channels, in_channels // groups, kH, kW)
Thus the original check
```
ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups),
                          wshape[0]));
```
is wrong. The input channels per group should be compared against `wshape[1]`, not `wshape[0]`.
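To make the difference concrete, here is a minimal pure-Python sketch of the two checks applied to the shapes from the example program in this PR. It is not TVM code: `original_check` and `proposed_check` are hypothetical helpers that mirror the C++ comparison, assuming the input/weight shape convention stated above.

```python
# Hypothetical illustration, not TVM code. Assumes the convention from this PR:
# input (N, C, iH, iW), weight (out_channels, in_channels // groups, kH, kW).

def original_check(dshape, wshape, groups):
    # Original: input channels per group compared against wshape[0].
    return dshape[1] // groups == wshape[0]

def proposed_check(dshape, wshape, groups):
    # Proposed: input channels per group compared against wshape[1].
    return dshape[1] // groups == wshape[1]

# Shapes taken from the example program in this PR (depthwise, groups == 32).
dshape = (1, 32, 224, 224)  # (batch, in_channels, iH, iW)
wshape = (32, 1, 3, 3)      # (out_channels, in_channels // groups, kH, kW)
groups = 32

print(original_check(dshape, wshape, groups))  # False: 32 // 32 = 1, but wshape[0] = 32
print(proposed_check(dshape, wshape, groups))  # True:  32 // 32 = 1 == wshape[1]
```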
In addition, the operator name in the debug output is wrong: all logging messages refer to `conv2d` rather than `conv2d_transposed`, which is confusing.
## Example that triggers the error in the current implementation
```python
import tvm
from tvm import relay

SEMVER = '#[version = "0.0.5"]\n'


def assert_graph_equal(lhs, rhs):
    tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)


def roundtrip(expr):
    # Print the module as text, parse it back, and check that nothing changed.
    x = tvm.parser.fromtext(expr.astext())
    assert_graph_equal(x, expr)


# Testing utilities for full modules.
def parse_module(code):
    mod = tvm.parser.parse(SEMVER + code)
    roundtrip(mod)
    return mod


# Depthwise conv2d_transpose: in_channels == groups == 32, weight (32, 1, 3, 3).
# The original channel check rejects this weight shape.
program = """
def @main(%input0: Tensor[(1, 32, 224, 224), float32],
          %v0_0_weight: Tensor[(32, 1, 3, 3), float32]) -> Tensor[(1, 32, 224, 224), float32] {
  /* test comment */
  %0 = nn.conv2d_transpose(%input0, %v0_0_weight, strides=[1, 1],
                           padding=[1, 1, 1, 1], groups=32, channels=32, kernel_size=[3, 3]);
  %0
}
"""

mod = parse_module(program)
print(mod)

target = "llvm"
lib = relay.build(mod, target=target, params=None)
print("build [fwd] pass successful")
```
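As a sanity check, the declared return type `Tensor[(1, 32, 224, 224), float32]` agrees with the usual transposed-convolution output-size formula, so the declared output shape itself should not be the source of the error. A minimal sketch, assuming dilation 1 and no output padding; `conv2d_transpose_out_size` is a hypothetical helper, not a TVM API:

```python
# Hypothetical helper (not a TVM API): output size of a transposed convolution
# along one spatial axis, assuming dilation = 1 and output_padding = 0.
def conv2d_transpose_out_size(in_size, kernel, stride, pad_before, pad_after):
    return (in_size - 1) * stride - pad_before - pad_after + kernel

# Parameters from the example: iH = iW = 224, 3x3 kernel, stride 1, padding 1 on every side.
out_h = conv2d_transpose_out_size(224, 3, 1, 1, 1)
out_w = conv2d_transpose_out_size(224, 3, 1, 1, 1)
print((1, 32, out_h, out_w))  # (1, 32, 224, 224)
```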