masahi opened a new pull request, #14669:
URL: https://github.com/apache/tvm/pull/14669

   Currently, `DataflowReshapeRewrite` expects a `call_tir(reshape, (...))` call to take exactly one input:
   
   
https://github.com/apache/tvm/blob/unity/src/relax/transform/rewrite_dataflow_reshape.cc#L66-L67
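
   The assumption amounts to a check along these lines (a minimal Python sketch of the logic only; the actual pass is implemented in C++, and the helper name here is hypothetical):

   ```python
   def expect_single_reshape_input(args):
       """Sketch of DataflowReshapeRewrite's current assumption:
       a call_tir(reshape, (...)) call carries exactly one input tensor."""
       if len(args) != 1:
           raise ValueError(f"expected a single reshape input, got {len(args)}")
       return args[0]
   ```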
   
   This is a reasonable assumption, but I ran into a case where it doesn't hold. Consider the following subgraph:
   
   ```
                split
            /     |     \
           /      |      \
          /       |       \
         /        |        \
    tuple get  tuple get  tuple get
        |         |         |
        |         |         |
    reshape    reshape    reshape
         \        |         /
          \       |        /
           \      |       /
            \     |      /
              attention
   ```       
   
   Since `split` and `attention` cannot be fused, each branch of `tuple get -> reshape` becomes a fused function whose input is a tuple. In `FuseTIR`, this tuple input is flattened, so the fused TIR `reshape` function receives three buffers as input, even though only one of them is actually reshaped. This is what the module looks like after `FuseTIR`:
   
   ```
   lv47_3 = R.call_tir(cls.split1, (lv46_2,), out_sinfo=[R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16")])

   lv1579: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[0]
   lv1580: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[1]
   lv1581_1: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[2]
   lv729_1 = R.call_tir(cls.fused_reshape5_cast17, (lv1579, lv1580, lv1581_1), out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float32"))

   lv1582: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[0]
   lv1583: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[1]
   lv1584: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[2]
   lv730_1 = R.call_tir(cls.fused_reshape5_cast171, (lv1582, lv1583, lv1584), out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float32"))

   lv1585: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[0]
   lv1586: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[1]
   lv1587: R.Tensor((2, 4096, 320), dtype="float16") = lv47_3[2]
   lv731_1 = R.call_tir(cls.fused_reshape5_cast172, (lv1585, lv1586, lv1587), out_sinfo=R.Tensor((2, 4096, 8, 40), dtype="float32"))
   ```
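
   Note that all three inputs have the same shape here, so the actually-used buffer cannot be identified from shapes alone; it has to be determined from which buffer the fused function's body reads. A minimal Python sketch of that idea (the helper `find_used_inputs` and the example read set are hypothetical; the real pass would walk the fused PrimFunc body to collect buffer reads):

   ```python
   def find_used_inputs(input_names, body_reads):
       """Given the flattened input buffer names of a fused TIR function
       and the set of buffer names its body actually reads, return the
       inputs that are really consumed."""
       return [name for name in input_names if name in body_reads]

   # Each fused reshape above receives three buffers, but its body
   # reads only one of them (which one is hypothetical here):
   inputs = ["lv1579", "lv1580", "lv1581_1"]
   reads = {"lv1581_1"}
   print(find_used_inputs(inputs, reads))  # ['lv1581_1']
   ```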
   
   `DataflowReshapeRewrite` breaks on such a module, since the `reshape` function now receives three inputs.
   
   Now, the TIR module above looks odd, but it is functionally correct. I wanted to change `FuseTIR` to emit only the buffers that are actually used when flattening the tuple, but that seems complicated. Modifying `DataflowReshapeRewrite` to accept such modules is a bit hacky, but it is a simple solution.
   
   
   @MasterJH5574 @Hzfengsy 

