Yongqi-Zhuo opened a new issue, #17165: URL: https://github.com/apache/tvm/issues/17165
Generally `reshape` is a no-op, and should be eliminated if possible. According to https://discuss.tvm.apache.org/t/te-using-reshape-without-copy/9480 and https://discuss.tvm.apache.org/t/can-tvm-do-reshape-without-any-extra-effort/1333, it is fine for a single `reshape` to be kept until fused with another op. However, in my own setup, TVM failed to fuse `reshape` with another op such as `adaptive_avg_pool2d`. Here is a minimal example. ```python import tvm from tvm import relax from tvm.script import ir as I from tvm.script import tir as T from tvm.script import relax as R @I.ir_module class Module: @R.function def main(inp_0: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"): with R.dataflow(): lv: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(inp_0, output_size=[1, 1], layout="NCHW", out_layout="NCHW") gv: R.Tensor((1, 512), dtype="float32") = R.reshape(lv, R.shape([1, 512])) R.output(gv) return gv if __name__ == "__main__": mod = Module print("Before fusion:") mod.show() seq = tvm.transform.Sequential([ relax.transform.LegalizeOps(enable_warning=True), relax.transform.AnnotateTIROpPattern(), relax.transform.FuseOps(), relax.transform.FuseTIR(), relax.transform.DeadCodeElimination(), ]) with tvm.transform.PassContext(opt_level=3, config={"relax.FuseOps.max_depth": 32}): mod = seq(mod) print("After fusion:") mod.show() ``` This example was extracted from ResNet-18. In this example, there are only 2 ops, i.e., avg pool ((1, 512, 7, 7) -> (1, 512, 1, 1)) and a no-op reshape ((1, 512, 1, 1) -> (1, 512)). 
It would be expected that the `reshape` gets fused into the avg pool, but the actual output is as follows ``` Before fusion: # from tvm.script import ir as I # from tvm.script import relax as R @I.ir_module class Module: @R.function def main(inp_0: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"): with R.dataflow(): lv: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(inp_0, output_size=[1, 1], layout="NCHW", out_layout="NCHW") gv: R.Tensor((1, 512), dtype="float32") = R.reshape(lv, R.shape([1, 512])) R.output(gv) return gv After fusion: # from tvm.script import ir as I # from tvm.script import tir as T # from tvm.script import relax as R @I.ir_module class Module: @T.prim_func(private=True) def adaptive_avg_pool2d(inp_0: T.Buffer((T.int64(1), T.int64(512), T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(1), T.int64(512), T.int64(1), T.int64(1)), "float32")): T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) # with T.block("root"): adaptive_pool_sum = T.alloc_buffer((T.int64(1), T.int64(512), T.int64(1), T.int64(1))) for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(7), T.int64(7)): with T.block("adaptive_pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) T.reads(inp_0[v_ax0, v_ax1, v_ax2 * T.int64(7) + v_rv0, v_ax3 * T.int64(7) + v_rv1]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + inp_0[v_ax0, v_ax1, v_ax2 * T.int64(7) + v_rv0, v_ax3 * T.int64(7) + v_rv1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(512), T.int64(1), T.int64(1)): with T.block("adaptive_pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, 
v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.020408163265306121) @T.prim_func(private=True) def reshape(lv: T.Buffer((T.int64(1), T.int64(512), T.int64(1), T.int64(1)), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(512)), "float32")): T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0, ax1 in T.grid(T.int64(1), T.int64(512)): with T.block("T_reshape"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(lv[T.int64(0), v_ax1 % T.int64(512), T.int64(0), T.int64(0)]) T.writes(T_reshape[v_ax0, v_ax1]) T_reshape[v_ax0, v_ax1] = lv[T.int64(0), v_ax1 % T.int64(512), T.int64(0), T.int64(0)] @R.function def main(inp_0: R.Tensor((1, 512, 7, 7), dtype="float32")) -> R.Tensor((1, 512), dtype="float32"): cls = Module with R.dataflow(): lv = R.call_tir(cls.adaptive_avg_pool2d, (inp_0,), out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32")) gv = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((1, 512), dtype="float32")) R.output(gv) return gv ``` ### Expected behavior The `reshape` should have been fused. 
An expected possible `PrimFunc` would be ```python def expected_fused(inp_0: T.Buffer((T.int64(1), T.int64(512), T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(1), T.int64(512)), "float32")): # with T.block("root"): adaptive_pool_sum = T.alloc_buffer((T.int64(1), T.int64(512), T.int64(1), T.int64(1))) for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(7), T.int64(7)): with T.block("adaptive_pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) T.reads(inp_0[v_ax0, v_ax1, v_ax2 * T.int64(7) + v_rv0, v_ax3 * T.int64(7) + v_rv1]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + inp_0[v_ax0, v_ax1, v_ax2 * T.int64(7) + v_rv0, v_ax3 * T.int64(7) + v_rv1] for ax0, ax1 in T.grid(T.int64(1), T.int64(512)): with T.block("adaptive_pool_avg"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(adaptive_pool_sum[v_ax0, v_ax1, T.int64(0), T.int64(0)]) T.writes(adaptive_pool_avg[v_ax0, v_ax1]) adaptive_pool_avg[v_ax0, v_ax1] = adaptive_pool_sum[v_ax0, v_ax1, T.int64(0), T.int64(0)] * T.float32(0.020408163265306121) ``` ### Actual behavior The two ops cannot be fused, as shown in the output. This means that TVM wastes time copying a buffer to another buffer with an identical shape. ### Environment OS: Arch Linux, kernel version 6.9.9 TVM: commit 292ecfd21031eef97d8750d553a3cf65c74ecaf8, Jun 14 2024 ### Steps to reproduce The script provided above. ### Triage Please refer to the list of label tags [here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the relevant tags and add them below in a bullet format (example below). 
* tir:transform -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
