Lunderberg opened a new pull request, #16596:
URL: https://github.com/apache/tvm/pull/16596

   This commit implements an optional optimization pass 
`relax.transform.ReorderPermuteDimsAfterConcat`, which reorders expressions of 
the form `R.concat(R.permute_dims(A), R.permute_dims(B))` into 
`R.permute_dims(R.concat(A, B))`.
   
   This pass is intended to be used alongside `CombineParallelMatmul`. After 
parallel matmuls are combined, the `R.permute_dims` calls can be lifted out of 
the `R.concat`, allowing optimized `nn.Linear` kernels to find the 
`R.matmul(x, R.permute_dims(weights))` patterns they are looking for.
   
   ```python
   @R.function
   def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, 
weight_value: R.Tensor):
       """Initial IRModule
   
       The `R.permute_dims` followed by `R.matmul` is the relax
       equivalent of `nn.Linear`, and will frequently have optimized
       kernels.
       """
       weight_query_T = R.permute_dims(weight_query)
       query = R.matmul(x, weight_query_T)
       weight_key_T = R.permute_dims(weight_key)
       key = R.matmul(x, weight_key_T)
       weight_value_T = R.permute_dims(weight_value)
       value = R.matmul(x, weight_value_T)
   
   @R.function
   def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, 
weight_value: R.Tensor):
       """After `CombineParallelMatmul`
   
       There's now only a single matmul to be performed, which is
       generally better than performing three small matmuls.  However,
       the optimized kernels for `nn.Linear` can no longer be applied,
       because the `R.concat` isn't part of the expected pattern.
       """
       weight_query_T = R.permute_dims(weight_query)
       weight_key_T = R.permute_dims(weight_key)
       weight_value_T = R.permute_dims(weight_value)
   
       fused_weight_T = R.concat([weight_query_T, weight_key_T, 
weight_value_T], axis=1)
       fused_qkv = R.matmul(x, fused_weight_T)
   
       query, key, value = R.split(fused_qkv)
   
   @R.function
   def func(x: R.Tensor, weight_query: R.Tensor, weight_key: R.Tensor, 
weight_value: R.Tensor):
       """After `ReorderPermuteDimsAfterConcat`
   
       There's still only a single matmul, and the optimized kernels for
       `nn.Linear` can be applied again.
       """
       fused_weight = R.concat([weight_query, weight_key, weight_value], axis=0)
   
       fused_weight_T = R.permute_dims(fused_weight)
       fused_qkv = R.matmul(x, fused_weight_T)
   
       query, key, value = R.split(fused_qkv)
   ```
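
The rewrite relies on the identity that transposition distributes over concatenation with the axis swapped: concatenating transposed matrices along axis 1 gives the same result as concatenating the originals along axis 0 and transposing once. A quick pure-Python sketch (plain nested lists standing in for tensors, not the TVM API) illustrates this:

```python
def transpose(mat):
    # 2-D transpose, the list-of-lists analogue of R.permute_dims.
    return [list(row) for row in zip(*mat)]

def concat(mats, axis):
    # Concatenate 2-D matrices along axis 0 (rows) or axis 1 (columns).
    if axis == 0:
        return [row for m in mats for row in m]
    return [sum((m[i] for m in mats), []) for i in range(len(mats[0]))]

# Two weight matrices with the same shape, as in the QKV example above.
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

# Before the pass: transpose each weight, then concatenate along axis=1.
before = concat([transpose(A), transpose(B)], axis=1)
# After the pass: concatenate along axis=0 first, then transpose once.
after = transpose(concat([A, B], axis=0))

assert before == after  # both forms produce the same fused weight
```

The two forms compute the same fused weight, but the second needs only a single `R.permute_dims`, restoring the `R.matmul(x, R.permute_dims(weights))` pattern.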

