lazycal opened a new issue #8759:
URL: https://github.com/apache/tvm/issues/8759
When running the AlterOpLayout pass on a conv2d followed by a strided_slice, with the layout altered from `NCHW4c` to `NCHW`, the compiler does nothing to the strided_slice, while (I think) the only correct behavior is to wrap it with two layout_transforms. This leads to an incorrect numerical result or a crash at InferType, depending on the concrete input shape, as shown in the following code snippet:
```python
import numpy as np

import tvm
import tvm.testing  # needed for tvm.testing.assert_allclose
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing.temp_op_attr import TempOpAttr


def test1(x_shape, w_shape):
    def before():
        x = relay.var("x", shape=x_shape)
        weight = relay.var("weight", shape=w_shape)
        y = relay.nn.conv2d(
            x,
            weight,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW4c",
            kernel_layout="OIHW4i4o",
        )
        # Stride-8 slice on the outer (primal) channel dimension of NCHW4c.
        y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
        y = relay.Function([x, weight], y)
        return tvm.IRModule.from_expr(y)

    # Alter the conv2d layout from NCHW4c/OIHW4i4o back to plain NCHW/OIHW.
    def alter_conv2d(attrs, inputs, tinfos, out_type):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW"
        new_attrs["kernel_layout"] = "OIHW"
        return relay.nn.conv2d(data, weight, **new_attrs)

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        be = transform.InferType()(before())
        print("=" * 40, "before", "=" * 40)
        print(be)
        af = transform.AlterOpLayout()(be)
        print("=" * 40, "after", "=" * 40)
        print(af)

        xnp = np.random.rand(*x_shape).astype(np.float32)
        wnp = np.random.rand(*w_shape).astype(np.float32)
        be_res = relay.create_executor("debug", be).evaluate()(xnp, wnp).numpy()
        af_res = relay.create_executor("debug", af).evaluate()(xnp, wnp).numpy()
        tvm.testing.assert_allclose(be_res, af_res, rtol=1e-3, atol=1e-3)


test1(x_shape=(1, 1, 1, 1, 4), w_shape=(9, 1, 3, 3, 4, 4))  # incorrect numerical result
# test1(x_shape=(1, 1, 1, 1, 4), w_shape=(11, 1, 3, 3, 4, 4))  # crash at InferType
```
The module before:
```cpp
def @main(%x: Tensor[(1, 1, 1, 1, 4), float32], %weight: Tensor[(9, 1, 3, 3, 4, 4), float32]) -> Tensor[(1, 1, 1, 1, 4), float32] {
  %0 = nn.conv2d(%x, %weight, padding=[1, 1, 1, 1], kernel_size=[3, 3], data_layout="NCHW4c", kernel_layout="OIHW4i4o") /* ty=Tensor[(1, 9, 1, 1, 4), float32] */;
  strided_slice(%0, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
}
```
and after:
```cpp
def @main(%x: Tensor[(1, 1, 1, 1, 4), float32], %weight: Tensor[(9, 1, 3, 3, 4, 4), float32]) -> Tensor[(1, 1, 1, 1, 4), float32] {
  %0 = layout_transform(%x, src_layout="NCHW4c", dst_layout="NCHW") /* ty=Tensor[(1, 4, 1, 1), float32] */;
  %1 = layout_transform(%weight, src_layout="OIHW4i4o", dst_layout="OIHW") /* ty=Tensor[(36, 4, 3, 3), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], kernel_size=[3, 3]) /* ty=Tensor[(1, 36, 1, 1), float32] */;
  %3 = strided_slice(%2, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 5, 1, 1), float32] */;
  layout_transform(%3, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
}
```
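For concreteness, here is a hand-written sketch of the wrapping I would expect around the strided_slice (my expectation, not actual compiler output; the shapes are taken from the repro above), so that its `begin`/`end`/`strides` keep indexing `NCHW4c` axes:
```python
import tvm
from tvm import relay

# Conv output in NCHW after the layout has been altered, per the IR above.
x = relay.var("x", shape=(1, 36, 1, 1))
# Go back to NCHW4c, slice with the original attributes, then return to NCHW.
y = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW4c")
y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
y = relay.layout_transform(y, src_layout="NCHW4c", dst_layout="NCHW")
print(relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x], y))))
```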
Specifically, I am doing `conv_NCHW4c_out[:, ::8, ...]` (an 8-stride slice on the primal `C` dimension of `NCHW4c`). After altering the layout to `NCHW`, the compiler neither wraps the strided_slice with any layout_transforms nor adjusts its attributes, so the semantics change to `conv_NCHW_out[:, ::8, ...]`, which picks 1 element out of every 8, whereas the correct slice on `conv_NCHW_out` picks 4 elements out of every 4*8=32.
It seems that `StridedSliceInferCorrectLayout` is responsible for this.
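A small NumPy analogy of the semantic change (`nchw`/`nchw4c` here are just stand-ins for the channel axis of the conv output, not TVM code):
```python
import numpy as np

C = 36                            # primal channel count of conv_NCHW_out above
nchw = np.arange(C)               # channel indices laid out as in NCHW
nchw4c = nchw.reshape(C // 4, 4)  # the same channels split into (C_outer, 4c)

want = nchw4c[::8].reshape(-1)    # stride-8 slice on C_outer, flattened back: 4 of every 32
got = nchw[::8]                   # what the altered module computes: 1 of every 8

print(want)  # [ 0  1  2  3 32 33 34 35]
print(got)   # [ 0  8 16 24 32]
```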
By the way, the layout_transform at the end of the latter IR looks strange:
```cpp
%3 = strided_slice(%2, begin=[0, 0], end=[1, -1], strides=[1, 8], axes=None) /* ty=Tensor[(1, 5, 1, 1), float32] */;
layout_transform(%3, src_layout="NCHW", dst_layout="NCHW4c") /* ty=Tensor[(1, 1, 1, 1, 4), float32] */
```
The resulting tensor has a smaller shape, `(1, 1, 1, 1, 4)`, than the pre-transform one, `(1, 5, 1, 1)`. I believe the reason is that `(1, 5, 1, 1)` is not a valid input for conversion to the `NCHW4c` layout, and I would have thought layout_transform should detect and reject that?
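Here is a minimal snippet isolating just that layout_transform; my assumption is that InferType ought to reject it, since 5 is not divisible by 4:
```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 5, 1, 1))
y = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW4c")
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# I would expect a type error here, but per the IR above the result is
# silently inferred as Tensor[(1, 1, 1, 1, 4)].
print(relay.transform.InferType()(mod))
```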
## Environment
- TVM: commit e334942db002019979438971440d33ece16585a3
- CUDA version: 10.0
- System: Ubuntu 16.04
- GCC 5.4
- Build options: -DUSE_RELAY_DEBUG=ON -DUSE_CUBLAS=ON -DUSE_LLVM=ON -DUSE_CUDA=ON -DUSE_CUDNN=ON