echuraev opened a new issue, #15358: URL: https://github.com/apache/tvm/issues/15358
While working on PR #15137, I found that the FuseOps pass fuses operations in a non-intuitive way: sometimes ops are not fused together and remain individual PrimFuncs. On this [branch](https://github.com/echuraev/tvm/tree/echuraev/fuse_ops_issue) I prepared a test script that reproduces this situation.

We have a base function that contains one operation:

```
fn (%x0: Tensor[(10, 20), float32]) {
  multiply(%x0, 2f)
}
```

In the test, we create 5 such operations and compute their sum (`base_func0 + base_func1 + ... + base_func4`). The resulting Relay function looks like this:

```
fn (%x0: Tensor[(10, 20), float32], %x1: Tensor[(10, 20), float32], %x2: Tensor[(10, 20), float32], %x3: Tensor[(10, 20), float32], %x4: Tensor[(10, 20), float32]) {
  %0 = multiply(%x0, 2f);
  %1 = multiply(%x1, 2f);
  %2 = add(%0, %1);
  %3 = multiply(%x2, 2f);
  %4 = add(%2, %3);
  %5 = multiply(%x3, 2f);
  %6 = add(%4, %5);
  %7 = multiply(%x4, 2f);
  add(%6, %7)
}
```

We want to set the fusion depth so that each PrimFunc contains at most two base functions. In this case, `max_fused_ops = (base_function_ops + 1) * number_of_fused_base_func`, where `base_function_ops = 1` is the number of operations in a base function and `number_of_fused_base_func = 2` is the maximum number of base functions in one PrimFunc, so `max_fused_ops = (1 + 1) * 2 = 4`. We add one to `base_function_ops` because fusing `N` base functions into one function also brings in `N-1` `add` operations that combine them, plus one more `add` that accumulates the result of the previous PrimFunc.

### Expected behavior

After fusion, I expected several base functions to be fused into each PrimFunc. For example, in the output below, 5 base functions are fused into 3 PrimFuncs: the first and second PrimFuncs together contain 4 base functions (two each), and the last PrimFunc combines the result of those 4 base functions with the fifth one.
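To make the fusion budget concrete, here is a minimal plain-Python sketch (not TVM code) that evaluates the `max_fused_ops` formula and checks it against the operator count of each PrimFunc in the expected output. The helper names `max_fused_ops` and `ops_when_fusing` are hypothetical and exist only for this illustration; the per-PrimFunc counts are read off the Relay dump by hand. Note that the first PrimFunc has no previous result to accumulate, so it uses one op fewer than the budget.

```python
# Sketch only: hypothetical helpers illustrating the budget from the issue.
# Nothing here calls TVM; the op counts are transcribed from the Relay dump.


def max_fused_ops(base_function_ops: int, number_of_fused_base_func: int) -> int:
    """Budget formula from the issue: (base_function_ops + 1) * N."""
    return (base_function_ops + 1) * number_of_fused_base_func


def ops_when_fusing(base_function_ops: int, n: int) -> int:
    """Ops in one PrimFunc that fuses n base functions.

    n base functions contribute n * base_function_ops ops, n - 1 adds
    combine them, and one more add accumulates the previous PrimFunc's result.
    """
    return n * base_function_ops + (n - 1) + 1


budget = max_fused_ops(base_function_ops=1, number_of_fused_base_func=2)

# Operators per PrimFunc in the expected output:
#   %6:  multiply, multiply, add       (no previous result to accumulate)
#   %8:  multiply, add, multiply, add
#   %10: multiply, add
expected = {"%6": 3, "%8": 4, "%10": 2}

print("max_fused_ops =", budget)
for name, n_ops in expected.items():
    print(f"{name}: {n_ops} ops, within budget: {n_ops <= budget}")
# 5 multiplies + 4 adds in the unfused function.
print("total ops:", sum(expected.values()))
```

The closed-form budget and the explicit count agree for any base-function size, since `n * b + (n - 1) + 1 = n * (b + 1)`.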
```
fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
  %6 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p12: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %4 = multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %5 = multiply(%p12, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    add(%4, %5) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %7 = %6(%x0, %x1) /* ty=Tensor[(10, 20), float32] */;
  %8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %1 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %2 = add(%p11, %1) /* ty=Tensor[(10, 20), float32] */;
    %3 = multiply(%p2, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    add(%2, %3) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %9 = %8(%x01, %7, %x2) /* ty=Tensor[(10, 20), float32] */;
  %10 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    add(%p1, %0) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %10(%x02, %9) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
```

### Actual behavior

With the current FuseOps pass implementation, I get different code: there are 5 PrimFuncs, and each of them contains only one base function. I suppose this is incorrect behavior; please correct me if I'm wrong.

```
fn (%x0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x3: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %x4: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
  %4 = fn (%p02: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    multiply(%p02, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %5 = fn (%p03: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    multiply(%p03, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %6 = %4(%x1) /* ty=Tensor[(10, 20), float32] */;
  %7 = %5(%x2) /* ty=Tensor[(10, 20), float32] */;
  %8 = fn (%p01: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p11: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p21: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %2 = multiply(%p01, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %3 = add(%2, %p11) /* ty=Tensor[(10, 20), float32] */;
    add(%3, %p21) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %9 = fn (%p04: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    multiply(%p04, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %10 = %8(%x0, %6, %7) /* ty=Tensor[(10, 20), float32] */;
  %11 = %9(%x4) /* ty=Tensor[(10, 20), float32] */;
  %12 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p1: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, %p2: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = multiply(%p0, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %1 = add(%p1, %0) /* ty=Tensor[(10, 20), float32] */;
    add(%1, %p2) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %12(%x3, %10, %11) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32], Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */
```

### Environment

Linux, TVM mainline

### Steps to reproduce

You can use the test from this commit: https://github.com/echuraev/tvm/commit/88f2d4b30dba5145b33947d9a2e81cbab0d1b8a3

### Triage

* needs-triage
* flow:relay
