echuraev opened a new issue, #15358:
URL: https://github.com/apache/tvm/issues/15358

   While working on PR #15137, I found that the FuseOps pass fuses operations in a non-intuitive way: sometimes ops are not fused together and remain individual PrimFuncs. On this [branch](https://github.com/echuraev/tvm/tree/echuraev/fuse_ops_issue) I prepared a test script that reproduces this situation.
   
   We have a base function which contains one operation:
   ```
   fn (%x0: Tensor[(10, 20), float32]) {
     multiply(%x0, 2f);
   }
   ```
   
   In the test we create 5 such base functions and compute their sum (`base_func0 + base_func1 + ... + base_func4`). The resulting Relay function looks as follows:
   
   ```
   fn (%x0: Tensor[(10, 20), float32], %x1: Tensor[(10, 20), float32], %x2: Tensor[(10, 20), float32], %x3: Tensor[(10, 20), float32], %x4: Tensor[(10, 20), float32]) {
     %0 = multiply(%x0, 2f);
     %1 = multiply(%x1, 2f);
     %2 = add(%0, %1);
     %3 = multiply(%x2, 2f);
     %4 = add(%2, %3);
     %5 = multiply(%x3, 2f);
     %6 = add(%4, %5);
     %7 = multiply(%x4, 2f);
     add(%6, %7)
   }
   ```
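   To make the shape of this chain explicit without building TVM, here is a small pure-Python sketch that regenerates the body of the listing above for any number of base functions. The helper name `chained_sum_listing` is hypothetical and it only produces text; it does not use the TVM API. It shows that `n` base functions yield `2n - 1` operations: `n` multiplies and `n - 1` adds.

```python
def chained_sum_listing(num_base_funcs: int) -> list[str]:
    """Hypothetical helper: body lines of the chained-sum Relay function.

    Only builds text in the same form as the listing above; TVM is not used.
    """
    if num_base_funcs < 1:
        raise ValueError("need at least one base function")
    exprs = []  # one Relay expression per op, in evaluation order
    acc = None  # index of the expression holding the running sum
    for i in range(num_base_funcs):
        exprs.append(f"multiply(%x{i}, 2f)")  # the base function itself
        mul = len(exprs) - 1
        if acc is None:
            acc = mul  # the first base function starts the sum
        else:
            exprs.append(f"add(%{acc}, %{mul})")  # fold it into the sum
            acc = len(exprs) - 1
    # Relay text binds every expression except the last one to %N.
    lines = [f"%{k} = {e};" for k, e in enumerate(exprs[:-1])]
    lines.append(exprs[-1])
    return lines


for line in chained_sum_listing(5):
    print(line)
```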
   
   We want to limit the fusion depth so that each PrimFunc contains at most two base functions. In this case `max_fused_ops = (base_function_ops + 1) * number_of_fused_base_func`, where `base_function_ops = 1` is the number of operations in the base function and `number_of_fused_base_func = 2` is the maximum number of base functions in one PrimFunc, so `max_fused_ops = 4`.
   
   In the formula we add one to `base_function_ops` because fusing `N` base functions into one PrimFunc also brings `N` `add` operations: `N-1` of them combine the base functions with each other, and one more adds the result of the previous PrimFunc.
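   The arithmetic can be checked with a tiny sketch. The function name `max_fused_ops` mirrors the variable in the formula; that the test passes this value to FuseOps as its depth limit is an assumption about the test on the branch, not something shown here.

```python
def max_fused_ops(base_function_ops: int, number_of_fused_base_func: int) -> int:
    """Upper bound on ops in one PrimFunc holding N base functions.

    Each of the N base functions contributes its own ops, plus N adds in
    total: N-1 adds combine the base functions and one more add folds in
    the running result coming from the previous PrimFunc.
    """
    return (base_function_ops + 1) * number_of_fused_base_func


print(max_fused_ops(1, 2))  # limit used in this issue
```

   The bound is tight only for PrimFuncs that receive a previous result; the very first PrimFunc in the chain needs one `add` fewer.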
   
   ### Expected behavior
   After fusion, I expected several base functions to be fused into one PrimFunc. For example, in the code below the 5 base functions are fused into 3 PrimFuncs: the first and the second PrimFuncs each contain two base functions (four in total), and the last PrimFunc combines the result of those four base functions with the fifth one.
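   The grouping in the expected output can be modelled as splitting the base functions, in evaluation order, into consecutive chunks of `number_of_fused_base_func`. The sketch below (hypothetical helper `expected_partition`; it describes the expected listing, not how FuseOps is actually implemented) counts the ops per PrimFunc under that assumption. For 5 base functions in chunks of 2 it gives 3, 4 and 2 ops, which matches the three PrimFuncs below and stays within `max_fused_ops = 4`.

```python
def expected_partition(num_base_funcs: int, funcs_per_prim: int,
                       base_function_ops: int = 1) -> list[int]:
    """Ops per PrimFunc if base functions are fused in consecutive chunks.

    Hypothetical model of the expected output, assuming the chained sum
    base_func0 + base_func1 + ... and chunking in evaluation order.
    """
    counts = []
    for start in range(0, num_base_funcs, funcs_per_prim):
        chunk = min(funcs_per_prim, num_base_funcs - start)
        ops = chunk * base_function_ops  # the base functions themselves
        ops += chunk - 1                 # adds combining them inside the chunk
        if start > 0:
            ops += 1                     # add folding in the previous PrimFunc
        counts.append(ops)
    return counts


print(expected_partition(5, 2))
```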
   
   ```
   fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
     %6 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p12: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       %4 = multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       %5 = multiply(%p12, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       add(%4, %5) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %7 = %6(%x0, %x1) /* ty=Tensor[(10, 20), float32] */;
     %8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       %1 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       %2 = add(%p11, %1) /* ty=Tensor[(10, 20), float32] */;
       %3 = multiply(%p2, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       add(%2, %3) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %9 = %8(%x01, %7, %x2) /* ty=Tensor[(10, 20), float32] */;
     %10 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       %0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       add(%p1, %0) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %10(%x02, %9) /* ty=Tensor[(10, 20), float32] */
   } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
   ```
   
   ### Actual behavior
   With the current FuseOps pass implementation, the result is different: there are 5 PrimFuncs, and each of them contains only one base function (a single `multiply`, in two cases followed by `add` operations). I suppose this is incorrect behavior; please correct me if I'm wrong.
   
   ```
   fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x3: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x4: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
     %4 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %5 = fn (%p03: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       multiply(%p03, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %6 = %4(%x1) /* ty=Tensor[(10, 20), float32] */;
     %7 = %5(%x2) /* ty=Tensor[(10, 20), float32] */;
     %8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p21: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       %2 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       %3 = add(%2, %p11) /* ty=Tensor[(10, 20), float32] */;
       add(%3, %p21) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %9 = fn (%p04: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       multiply(%p04, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %10 = %8(%x0, %6, %7) /* ty=Tensor[(10, 20), float32] */;
     %11 = %9(%x4) /* ty=Tensor[(10, 20), float32] */;
     %12 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
       %0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
       %1 = add(%p1, %0) /* ty=Tensor[(10, 20), float32] */;
       add(%1, %p2) /* ty=Tensor[(10, 20), float32] */
     } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
     %12(%x3, %10, %11) /* ty=Tensor[(10, 20), float32] */
   } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
   ```
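   To compare the expected and actual dumps mechanically, a rough sketch is to count the `multiply` calls (i.e. base functions, since `base_function_ops = 1`) inside each `Primitive=1` function of the printed Relay text. The helper `multiplies_per_primfunc` is hypothetical and relies on assumptions about the text format: the body of each fused function starts after the first `{` following `Primitive=1)` and ends at the first line that begins with `}`. It parses text only and does not use the TVM API.

```python
import re

# Body of a fused (Primitive=1) function: from the "{" that opens it up to the
# first closing "}" at the start of a line. Assumes bodies contain no nested
# functions, which holds for the FuseOps dumps in this issue.
_PRIM_BODY = re.compile(r"Primitive=1\)[^{]*\{(.*?)\n\s*\}", re.DOTALL)


def multiplies_per_primfunc(relay_text: str) -> list[int]:
    """Number of multiply calls in each Primitive=1 function, in print order."""
    return [body.count("multiply(") for body in _PRIM_BODY.findall(relay_text)]
```

   Applied to the two dumps above, this should report two, two and one `multiply` for the expected output, and one `multiply` per PrimFunc for the actual output.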
   
   ### Environment
   
   Linux, TVM mainline
   
   ### Steps to reproduce
   You can use the test from this commit: 
https://github.com/echuraev/tvm/commit/88f2d4b30dba5145b33947d9a2e81cbab0d1b8a3
   
   ### Triage
   
   * needs-triage
   * flow:relay
   

