comaniac commented on pull request #7675:
URL: https://github.com/apache/tvm/pull/7675#issuecomment-800626924


   Pushed a new commit to also reorder the reshape_b and transpose so that the 
SimplifyExpr pass can cancel redundant transposes.
   
   Before this PR:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), 
float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), 
float32] */;
     %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(10, 4, 5), 
float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(10, 3, 5), float32] */;
     reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) 
-> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input0, newshape=[-1, 3, 4]) /* ty=Tensor[(10, 3, 4), 
float32] */;
     %1 = reshape(%input1, newshape=[-1, 4, 5]) /* ty=Tensor[(1, 4, 5), 
float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(1, 5, 4), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(10, 3, 5), float32] */;
     reshape(%3, newshape=[10, 3, 5]) /* ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 
14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
     %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), 
float32] */;
     %1 = reshape(%input1, newshape=[-1, 64, 14]) /* ty=Tensor[(12, 64, 14), 
float32] */;
     %2 = transpose(%1, axes=[0, 2, 1]) /* ty=Tensor[(12, 14, 64), float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(12, 14, 14), float32] */;
     reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), 
float32] */
   }
   ```
   
   After this PR:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(10, 4, 5), 
float32]) -> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[0, 2, 1]) /* ty=Tensor[(10, 5, 4), float32] 
*/;
     nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(4, 5), float32]) 
-> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */;
     %1 = reshape(%0, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %1, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(10, 3, 5), float32] */
   }
   
   fn (%input0: Tensor[(1, 12, 14, 64), float32], %input1: Tensor[(1, 12, 64, 
14), float32]) -> Tensor[(1, 12, 14, 14), float32] {
     %0 = reshape(%input0, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), 
float32] */;
     %1 = transpose(%input1, axes=[0, 1, 3, 2]) /* ty=Tensor[(1, 12, 14, 64), 
float32] */;
     %2 = reshape(%1, newshape=[-1, 14, 64]) /* ty=Tensor[(12, 14, 64), 
float32] */;
     %3 = nn.batch_matmul(%0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(12, 14, 14), float32] */;
     reshape(%3, newshape=[1, 12, 14, 14]) /* ty=Tensor[(1, 12, 14, 14), 
float32] */
   }
   ```
   
   In particular, since the weights in most PyTorch models have to be 
transposed when converting to Relay, the second case above, for example, could become:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) 
-> Tensor[(10, 3, 5), float32] {
     %0 = transpose(%input1, axes=[1, 0]) /* ty=Tensor[(4, 5), float32] */; <- 
Not added by matmul
     %1 = transpose(%0, axes=[1, 0]) /* ty=Tensor[(5, 4), float32] */; <- Added 
by matmul
     %2 = reshape(%1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), float32] */;
     nn.batch_matmul(%input0, %2, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(10, 3, 5), float32] */
   }
   ```
   
   By applying SimplifyExpr to cancel the unnecessary `transpose` ops, we could have:
   
   ```
   fn (%input0: Tensor[(10, 3, 4), float32], %input1: Tensor[(5, 4), float32]) 
-> Tensor[(10, 3, 5), float32] {
     %0 = reshape(%input1, newshape=[-1, 5, 4]) /* ty=Tensor[(1, 5, 4), 
float32] */;
     nn.batch_matmul(%input0, %0, meta[relay.attrs.BatchMatmulAttrs][0]) /* 
ty=Tensor[(10, 3, 5), float32] */
   }
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to