masahi opened a new pull request, #14669: URL: https://github.com/apache/tvm/pull/14669
Currently, `DataflowReshapeRewrite` expects that `call_tir(reshape, (...))` takes only one input: https://github.com/apache/tvm/blob/unity/src/relax/transform/rewrite_dataflow_reshape.cc#L66-L67

This makes sense, but I ran into a case where this assumption doesn't hold. Consider the following subgraph:

```
               split
            /    |    \
           /     |     \
          /      |      \
         /       |       \
  tuple get  tuple get  tuple get
      |          |          |
      |          |          |
   reshape    reshape    reshape
         \       |       /
          \      |      /
           \     |     /
            \    |    /
             attention
```

Since `split` and `attention` cannot be fused, each `tuple get -> reshape` branch becomes a fused function whose input is a tuple. `FuseTIR` flattens this tuple input, so the fused TIR `reshape` function receives three buffers as input, even though only one of them is actually reshaped. This is what it looks like after `FuseTIR`:

```
lv47_3 = R.call_tir(cls.split1, (lv46_2,), out_sinfo=[R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16")])
lv1579: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[0]
lv1580: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[1]
lv1581_1: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[2]
lv729_1 = R.call_tir(cls.fused_reshape5_cast17, (lv1579, lv1580, lv1581_1), out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float32"))
lv1582: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[0]
lv1583: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[1]
lv1584: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[2]
lv730_1 = R.call_tir(cls.fused_reshape5_cast171, (lv1582, lv1583, lv1584), out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float32"))
lv1585: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[0]
lv1586: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[1]
lv1587: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[2]
lv731_1 = R.call_tir(cls.fused_reshape5_cast172, (lv1585, lv1586, lv1587), out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float32"))
```

`DataflowReshapeRewrite` fails on such a module because the `reshape` function receives
three inputs.

The module above looks odd, but it is functionally correct. I wanted to change `FuseTIR` to emit only the buffers that are actually used when it flattens the tuple, but that seems complicated. Modifying `DataflowReshapeRewrite` to accept such a module is a bit hacky, but it is a simple solution.

@MasterJH5574 @Hzfengsy
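To make the proposed relaxation concrete, here is a minimal, self-contained Python sketch. It does not use TVM's API: `PrimFuncSketch`, `flatten_tuple_input`, `strict_reshaped_input`, and `relaxed_reshaped_input` are hypothetical names that model a fused TIR function as a list of buffer parameters plus the set of parameters its body reads. It contrasts the one-input check with a relaxed check that tolerates unused flattened inputs as long as exactly one input is actually read. This is not necessarily how the PR implements the change.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Set


@dataclass
class PrimFuncSketch:
    """Hypothetical stand-in for a fused TIR function: its buffer parameters
    (inputs first, output last) and the subset its body actually reads."""

    name: str
    params: List[str]
    read_params: Set[str]


def flatten_tuple_input(name: str, num_fields: int, used_index: int) -> PrimFuncSketch:
    """Model the FuseTIR behavior described above: every field of the tuple
    input becomes its own buffer parameter, but the body only reads the one
    field that the `tuple get` selected."""
    inputs = [f"p{i}" for i in range(num_fields)]
    return PrimFuncSketch(name, inputs + ["T_reshape"], {inputs[used_index]})


def strict_reshaped_input(call_args: Sequence[str]) -> Optional[str]:
    """The assumption being relaxed: call_tir passes exactly one argument."""
    return call_args[0] if len(call_args) == 1 else None


def relaxed_reshaped_input(func: PrimFuncSketch, call_args: Sequence[str]) -> Optional[str]:
    """Accept extra, unused inputs as long as exactly one input is read.

    Returns the call_tir argument bound to the single read parameter, or None
    if the call does not fit the pattern."""
    inputs = func.params[:-1]  # the last parameter is the output buffer
    if len(inputs) != len(call_args):
        return None
    read = [arg for param, arg in zip(inputs, call_args) if param in func.read_params]
    return read[0] if len(read) == 1 else None


if __name__ == "__main__":
    # Arguments copied from the dump above; which field each function reads
    # (0, 1, 2) is an assumption made for illustration.
    calls = [
        ("fused_reshape5_cast17", ["lv1579", "lv1580", "lv1581_1"]),
        ("fused_reshape5_cast171", ["lv1582", "lv1583", "lv1584"]),
        ("fused_reshape5_cast172", ["lv1585", "lv1586", "lv1587"]),
    ]
    for idx, (name, args) in enumerate(calls):
        func = flatten_tuple_input(name, len(args), idx)
        print(name, strict_reshaped_input(args), relaxed_reshaped_input(func, args))
```

Under these assumptions the strict check returns `None` for all three calls, while the relaxed check recovers `lv1579`, `lv1583`, and `lv1587`. Requiring exactly one read input keeps the relaxed check as conservative as the original for genuinely multi-input functions.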
