xinetzone opened a new issue, #17151:
URL: https://github.com/apache/tvm/issues/17151

   Code:
   ```python
   import numpy as np
   import torch
   from torch import nn, Tensor
   from torch.nn import functional as F
   from tvm import relay
   
   def _make_divisible(v, divisor, min_value=None):
       """
       This function is taken from the original tf repo.
       It ensures that all layers have a channel number that is divisible by 8
       It can be seen here:
       
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
       :param v:
       :param divisor:
       :param min_value:
       :return:
       """
       if min_value is None:
           min_value = divisor
       new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
       # Make sure that round down does not go down by more than 10%.
       if new_v < 0.9 * v:
           new_v += divisor
       return new_v
   
   class SqueezeExcitation(nn.Module):
       def __init__(self, input_channels: int, squeeze_factor: int = 4):
           super().__init__()
           squeeze_channels = _make_divisible(input_channels // squeeze_factor, 
8)
           self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
           self.relu = nn.ReLU(inplace=True)
           self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
   
       def _scale(self, input: Tensor, inplace: bool) -> Tensor:
           scale = F.adaptive_avg_pool2d(input, 1)
           scale = self.fc1(scale)
           scale = self.relu(scale)
           scale = self.fc2(scale)
           return F.hardsigmoid(scale, inplace=inplace)
   
       def forward(self, input: Tensor) -> Tensor:
           scale = self._scale(input, True)
           return scale * input
   
   class M(nn.Module):
       def __init__(self, input_channels: int=16):
           super().__init__()
           self.conv = nn.Conv2d(input_channels, 64, 1, bias=False)
           self.se_layer = SqueezeExcitation(input_channels)
   
       def forward(self, x: Tensor) -> Tensor:
           x = self.se_layer(x)
           x = self.conv(x)
           return x
   
   name = "x"
   shape = (1, 16, 64, 48)
   data_np = (np.random.randint(0, 256, shape)/255).astype("float32")
   data_torch = torch.from_numpy(data_np)
   
   model = M(input_channels=16).eval()
   scripted_model = torch.jit.trace(model, data_torch).eval()
   shape_list = [(name, shape)]
   mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
   ```
   
   Then running
   
   ```python
   print(mod["main"])
   ```
   
   shows:
   
   ```
   fn (%x: Tensor[(1, 16, 64, 48), float32] /* 
span=aten::adaptive_avg_pool2d_0.x:0:0 */, %aten::_convolution_0.weight: 
Tensor[(8, 16, 1, 1), float32] /* span=aten::_convolution_0.weight:0:0 */, 
%aten::_convolution_0.bias: Tensor[(8), float32] /* 
span=aten::_convolution_0.bias:0:0 */, %aten::_convolution_1.weight: 
Tensor[(16, 8, 1, 1), float32] /* span=aten::_convolution_1.weight:0:0 */, 
%aten::_convolution_1.bias: Tensor[(16), float32] /* 
span=aten::_convolution_1.bias:0:0 */, %aten::_convolution_2.weight: 
Tensor[(64, 16, 1, 1), float32] /* span=aten::_convolution_2.weight:0:0 */) {
     %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* 
span=aten::adaptive_avg_pool2d_0:0:0 */;
     %1 = nn.conv2d(%0, %aten::_convolution_0.weight, padding=[0, 0, 0, 0], 
channels=8, kernel_size=[1, 1]) /* span=aten::_convolution_0:0:0 */;
     %2 = nn.bias_add(%1, %aten::_convolution_0.bias) /* 
span=aten::_convolution_0:0:0 */;
     %3 = nn.relu(%2) /* span=aten::relu__0:0:0 */;
     %4 = nn.conv2d(%3, %aten::_convolution_1.weight, padding=[0, 0, 0, 0], 
channels=16, kernel_size=[1, 1]) /* span=aten::_convolution_1:0:0 */;
     %5 = nn.bias_add(%4, %aten::_convolution_1.bias) /* 
span=aten::_convolution_1:0:0 */;
     %6 = add(%5, 3f /* span=aten::hardsigmoid__0:0:0 */) /* 
span=aten::hardsigmoid__0:0:0 */;
     %7 = clip(%6, a_min=0f, a_max=6f) /* span=aten::hardsigmoid__0:0:0 */;
     %8 = divide(%7, 6f /* span=aten::hardsigmoid__0:0:0 */) /* 
span=aten::hardsigmoid__0:0:0 */;
     %9 = multiply(%8, %x) /* span=aten::mul_0:0:0 */;
     nn.conv2d(%9, %aten::_convolution_2.weight, padding=[0, 0, 0, 0], 
channels=64, kernel_size=[1, 1]) /* span=aten::_convolution_2:0:0 */
   }
   ```
   
   BUT
   
   ```python
   from copy import deepcopy
   import tvm
   from tvm import relay
   from tvm.relay.quantize.quantize import _bind_params
   optimize = tvm.transform.Sequential(
       [
           relay.transform.SimplifyInference(),
           relay.transform.FoldConstant(),
           relay.transform.FoldScaleAxis(),
           # relay.transform.CanonicalizeOps(),
           # relay.transform.FoldConstant(),
       ]
   )
   run_mod = deepcopy(mod)
   run_mod["main"] = _bind_params(run_mod["main"], params)
   with tvm.transform.PassContext(opt_level=3):
       # run_mod2 = relay.quantize.prerequisite_optimize(deepcopy(mod), params)
       run_mod = optimize(run_mod)
   print(run_mod["main"])
   ```
   
   shows the erroneous result:
   ```
   fn (%x: Tensor[(1, 16, 64, 48), float32] /* ty=Tensor[(1, 16, 64, 48), 
float32] span=aten::adaptive_avg_pool2d_0.x:0:0 */) -> Tensor[(1, 64, 64, 48), 
float32] {
     %0 = nn.adaptive_avg_pool2d(%x, output_size=[1, 1]) /* ty=Tensor[(1, 16, 
1, 1), float32] span=aten::adaptive_avg_pool2d_0:0:0 */;
     %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(8, 16, 1, 1), 
float32] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* 
ty=Tensor[(1, 8, 1, 1), float32] span=aten::_convolution_0:0:0 */;
     %2 = nn.bias_add(%1, meta[relay.Constant][2] /* ty=Tensor[(8), float32] 
*/) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten::_convolution_0:0:0 */;
     %3 = nn.relu(%2) /* ty=Tensor[(1, 8, 1, 1), float32] 
span=aten::relu__0:0:0 */;
     %4 = nn.conv2d(%3, meta[relay.Constant][3] /* ty=Tensor[(16, 8, 1, 1), 
float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* 
ty=Tensor[(1, 16, 1, 1), float32] span=aten::_convolution_1:0:0 */;
     %5 = nn.bias_add(%4, meta[relay.Constant][4] /* ty=Tensor[(16), float32] 
*/) /* ty=Tensor[(1, 16, 1, 1), float32] span=aten::_convolution_1:0:0 */;
     %6 = add(%5, 3f /* ty=float32 span=aten::hardsigmoid__0:0:0 */) /* 
ty=Tensor[(1, 16, 1, 1), float32] span=aten::hardsigmoid__0:0:0 */;
     %7 = clip(%6, a_min=0f, a_max=6f) /* ty=Tensor[(1, 16, 1, 1), float32] 
span=aten::hardsigmoid__0:0:0 */;
     %8 = divide(%7, 6f /* ty=float32 span=aten::hardsigmoid__0:0:0 */) /* 
ty=Tensor[(1, 16, 1, 1), float32] span=aten::hardsigmoid__0:0:0 */;
     %9 = squeeze(%8, axis=[0, 2, 3]) /* ty=Tensor[(16), float32] */;
     %10 = expand_dims(%9, axis=1, num_newaxis=2) /* ty=Tensor[(16, 1, 1), 
float32] */;
     %11 = multiply(meta[relay.Constant][0] /* ty=Tensor[(64, 16, 1, 1), 
float32] */, %10) /* ty=Tensor[(64, 16, 1, 1), float32] */;
     nn.conv2d(%x, %11, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) 
/* ty=Tensor[(1, 64, 64, 48), float32] */
   } /* ty=fn (Tensor[(1, 16, 64, 48), float32]) -> Tensor[(1, 64, 64, 48), 
float32] */
   
   
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to