xinetzone opened a new issue, #15325:
URL: https://github.com/apache/tvm/issues/15325

   ### Test Case
   
   ```python
   import torch
   from torch import nn
   from tvm import relay
   
   class Model(nn.Module):
       """Minimal Conv2d -> BatchNorm2d -> ReLU network used to reproduce the issue."""

       def __init__(self, *args, **kwargs) -> None:
           super().__init__(*args, **kwargs)
           # 3 input channels -> 16 output channels, 3x3 kernel, stride 1,
           # padding 1 (spatial size is preserved), no bias term.
           self.conv = nn.Conv2d(
               in_channels=3,
               out_channels=16,
               kernel_size=3,
               stride=1,
               padding=1,
               bias=False,
           )
           self.bn = nn.BatchNorm2d(num_features=16)
           self.relu = nn.ReLU(inplace=True)

       def forward(self, x):
           """Apply convolution, batch normalization and ReLU, in that order."""
           return self.relu(self.bn(self.conv(x)))
   
   # Inference only: disable autograd globally before tracing.
   torch.set_grad_enabled(False)
   input_shape = [1, 3, 8, 8]
   input_data = torch.rand(input_shape).float()
   pt_model = Model().eval().float()
   # (input name, shape) pairs handed to the Relay PyTorch frontend.
   input_shapes = [("data", input_shape)]
   traced_model = torch.jit.trace(pt_model, input_data)
   # Convert the traced model into a TVM Relay frontend module.
   mod, params = relay.frontend.from_pytorch(traced_model, input_shapes)
   mod = relay.transform.InferType()(mod)
   run_mod = relay.quantize.prerequisite_optimize(mod, params)
   partition_mod = relay.quantize.partition()(run_mod)
   # Show the partitioned "main" function (fixed typo: was `prinjt`).
   print(partition_mod["main"])
   ```
   
   
   ### Expected behavior
   
   ```
   fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] 
span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 16, 8, 8), float32] {
     %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), 
float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* 
ty=Tensor[(1, 16, 8, 8), float32] span=aten::_convolution_0:0:0 */;
     %1 = multiply(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), 
float32] */) /* ty=Tensor[(1, 16, 8, 8), float32] */;
     %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) 
/* ty=Tensor[(1, 16, 8, 8), float32] */;
     %3 = nn.relu(%2) /* ty=Tensor[(1, 16, 8, 8), float32] 
span=aten::relu__0:0:0 */;
     %4 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 16, 8, 8), 
float32] */;
     annotation.stop_fusion(%4) /* ty=Tensor[(1, 16, 8, 8), float32] */
   } /* ty=fn (Tensor[(1, 3, 8, 8), float32]) -> Tensor[(1, 16, 8, 8), float32] 
*/
   ```
   
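   ### Actual behavior
   
   An extra `annotation.cast_hint` / `annotation.stop_fusion` pair is inserted directly after `nn.conv2d` (before `multiply`); it does not appear in the expected output above: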
   
   ```
   fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] 
span=aten::_convolution_0.data:0:0 */) -> Tensor[(1, 16, 8, 8), float32] {
     %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), 
float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* 
ty=Tensor[(1, 16, 8, 8), float32] span=aten::_convolution_0:0:0 */;
     %1 = annotation.cast_hint(%0, dtype="int8") /* ty=Tensor[(1, 16, 8, 8), 
float32] */;
     %2 = annotation.stop_fusion(%1) /* ty=Tensor[(1, 16, 8, 8), float32] */;
     %3 = multiply(%2, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), 
float32] */) /* ty=Tensor[(1, 16, 8, 8), float32] */;
     %4 = add(%3, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) 
/* ty=Tensor[(1, 16, 8, 8), float32] */;
     %5 = nn.relu(%4) /* ty=Tensor[(1, 16, 8, 8), float32] 
span=aten::relu__0:0:0 */;
     %6 = annotation.cast_hint(%5, dtype="int8") /* ty=Tensor[(1, 16, 8, 8), 
float32] */;
     annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 8, 8), float32] */
   } /* ty=fn (Tensor[(1, 3, 8, 8), float32]) -> Tensor[(1, 16, 8, 8), float32] 
*/
   ```
   
   ### My strategy
   Change `mul_partition_generic` so that it delegates to `add_partition_generic`:
   ```python
   def mul_partition_generic(ref_call, new_args, ctx):
       """Partition rewrite for elementwise multiply on generic devices.

       Forwards all three arguments unchanged to ``add_partition_generic`` and
       returns its result, so multiply is handled exactly like add.
       """
       rewritten = add_partition_generic(ref_call, new_args, ctx)
       return rewritten
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to