Yongqi-Zhuo opened a new issue, #17165:
URL: https://github.com/apache/tvm/issues/17165

   Generally, `reshape` is a no-op and should be eliminated when possible. 
According to 
https://discuss.tvm.apache.org/t/te-using-reshape-without-copy/9480 and 
https://discuss.tvm.apache.org/t/can-tvm-do-reshape-without-any-extra-effort/1333,
 it is acceptable for a single `reshape` to remain until it is fused with 
another op. However, in my setup, TVM failed to fuse `reshape` with an adjacent 
op such as `adaptive_avg_pool2d`.
   
   Here is a minimal example.
   
   ```python
   import tvm
   from tvm import relax
   from tvm.script import ir as I
   from tvm.script import tir as T
   from tvm.script import relax as R
   
   @I.ir_module
   class Module:
       @R.function
       def main(inp_0: R.Tensor((1, 512, 7, 7), dtype="float32")) -> 
R.Tensor((1, 512), dtype="float32"):
           with R.dataflow():
               lv: R.Tensor((1, 512, 1, 1), dtype="float32") = 
R.nn.adaptive_avg_pool2d(inp_0, output_size=[1, 1], layout="NCHW", 
out_layout="NCHW")
               gv: R.Tensor((1, 512), dtype="float32") = R.reshape(lv, 
R.shape([1, 512]))
               R.output(gv)
           return gv
   
   if __name__ == "__main__":
       mod = Module
       print("Before fusion:")
       mod.show()
       seq = tvm.transform.Sequential([
           relax.transform.LegalizeOps(enable_warning=True),
           relax.transform.AnnotateTIROpPattern(),
           relax.transform.FuseOps(),
           relax.transform.FuseTIR(),
           relax.transform.DeadCodeElimination(),
       ])
       with tvm.transform.PassContext(opt_level=3, 
config={"relax.FuseOps.max_depth": 32}):
           mod = seq(mod)
       print("After fusion:")
       mod.show()
   ```
   
   This example was extracted from ResNet-18. It contains only 2 ops: an 
adaptive average pool ((1, 512, 7, 7) -> (1, 512, 1, 1)) and a no-op reshape 
((1, 512, 1, 1) -> (1, 512)). The reshape is expected to be fused into the avg 
pool, but the actual output is as follows
   
   ```
   Before fusion:
   # from tvm.script import ir as I
   # from tvm.script import relax as R
   
   @I.ir_module
   class Module:
       @R.function
       def main(inp_0: R.Tensor((1, 512, 7, 7), dtype="float32")) -> 
R.Tensor((1, 512), dtype="float32"):
           with R.dataflow():
               lv: R.Tensor((1, 512, 1, 1), dtype="float32") = 
R.nn.adaptive_avg_pool2d(inp_0, output_size=[1, 1], layout="NCHW", 
out_layout="NCHW")
               gv: R.Tensor((1, 512), dtype="float32") = R.reshape(lv, 
R.shape([1, 512]))
               R.output(gv)
           return gv
   
   After fusion:
   # from tvm.script import ir as I
   # from tvm.script import tir as T
   # from tvm.script import relax as R
   
   @I.ir_module
   class Module:
       @T.prim_func(private=True)
       def adaptive_avg_pool2d(inp_0: T.Buffer((T.int64(1), T.int64(512), 
T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(1), 
T.int64(512), T.int64(1), T.int64(1)), "float32")):
           T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
           # with T.block("root"):
           adaptive_pool_sum = T.alloc_buffer((T.int64(1), T.int64(512), 
T.int64(1), T.int64(1)))
           for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(1), T.int64(512), 
T.int64(1), T.int64(1), T.int64(7), T.int64(7)):
               with T.block("adaptive_pool_sum"):
                   v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = 
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                   T.reads(inp_0[v_ax0, v_ax1, v_ax2 * T.int64(7) + v_rv0, 
v_ax3 * T.int64(7) + v_rv1])
                   T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                   with T.init():
                       adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.float32(0)
                   adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + inp_0[v_ax0, v_ax1, v_ax2 * 
T.int64(7) + v_rv0, v_ax3 * T.int64(7) + v_rv1]
           for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(512), 
T.int64(1), T.int64(1)):
               with T.block("adaptive_pool_avg"):
                   v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
                   T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                   T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                   T.block_attr({"schedule_rule": 
"meta_schedule.adaptive_pool_avg"})
                   adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.020408163265306121)
   
       @T.prim_func(private=True)
       def reshape(lv: T.Buffer((T.int64(1), T.int64(512), T.int64(1), 
T.int64(1)), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(512)), 
"float32")):
           T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
           # with T.block("root"):
           for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
               with T.block("T_reshape"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(lv[T.int64(0), v_ax1 % T.int64(512), T.int64(0), 
T.int64(0)])
                   T.writes(T_reshape[v_ax0, v_ax1])
                   T_reshape[v_ax0, v_ax1] = lv[T.int64(0), v_ax1 % 
T.int64(512), T.int64(0), T.int64(0)]
   
       @R.function
       def main(inp_0: R.Tensor((1, 512, 7, 7), dtype="float32")) -> 
R.Tensor((1, 512), dtype="float32"):
           cls = Module
           with R.dataflow():
               lv = R.call_tir(cls.adaptive_avg_pool2d, (inp_0,), 
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"))
               gv = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((1, 512), 
dtype="float32"))
               R.output(gv)
           return gv
   ```
   
   ### Expected behavior
   
   The `reshape` should have been fused. An expected possible `PrimFunc` would 
be
   
   ```python
       def expected_fused(inp_0: T.Buffer((T.int64(1), T.int64(512), 
T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(1), 
T.int64(512)), "float32")):
           # with T.block("root"):
           adaptive_pool_sum = T.alloc_buffer((T.int64(1), T.int64(512), 
T.int64(1), T.int64(1)))
           for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(1), T.int64(512), 
T.int64(1), T.int64(1), T.int64(7), T.int64(7)):
               with T.block("adaptive_pool_sum"):
                   v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = 
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                   T.reads(inp_0[v_ax0, v_ax1, v_ax2 * T.int64(7) + v_rv0, 
v_ax3 * T.int64(7) + v_rv1])
                   T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                   with T.init():
                       adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.float32(0)
                   adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + inp_0[v_ax0, v_ax1, v_ax2 * 
T.int64(7) + v_rv0, v_ax3 * T.int64(7) + v_rv1]
           for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
               with T.block("adaptive_pool_avg"):
                   v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                   T.reads(adaptive_pool_sum[v_ax0, v_ax1, T.int64(0), 
T.int64(0)])
                   T.writes(adaptive_pool_avg[v_ax0, v_ax1])
                   adaptive_pool_avg[v_ax0, v_ax1] = adaptive_pool_sum[v_ax0, 
v_ax1, T.int64(0), T.int64(0)] * T.float32(0.020408163265306121)
   ```
   
   ### Actual behavior
   
   The two ops cannot be fused, as shown in the output above. This means TVM 
wastes time copying a buffer into another buffer of the same size and contents 
(only the shape metadata differs).
   
   ### Environment
   
   Any environment details, such as: Operating System, TVM version, etc
   
   OS: Arch Linux, kernel version 6.9.9
   
   TVM: commit 292ecfd21031eef97d8750d553a3cf65c74ecaf8, Jun 14 2024
   
   ### Steps to reproduce
   
   The script provided above.
   
   ### Triage
   
   Please refer to the list of label tags 
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the 
relevant tags and add them below in a bullet format (example below).
   
   * tir:transform
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to