Lunderberg commented on PR #16315:
URL: https://github.com/apache/tvm/pull/16315#issuecomment-1887437915

   This pass would be part of an optimization pipeline for batched LoRA 
models. In that usage, the `R.take(weights, indices)` selects which LoRA 
should be used for each prompt in the batch. By rearranging the matmul, the 
resulting `R.matmul(x, weights)` can be combined with 
`R.matmul(x, base_weights)` using `CombineParallelMatmul`.
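   A minimal NumPy sketch of the algebraic identity the rearrangement relies 
on (shapes, sizes, and variable names below are illustrative, not the pass's 
actual implementation): gathering per-prompt weights and then multiplying is 
equivalent to multiplying against *all* LoRA weights and then gathering the 
result.

```python
import numpy as np

B, M, K, N, L = 2, 3, 4, 5, 3  # batch, rows, inner dim, cols, number of LoRAs
rng = np.random.default_rng(0)
x = rng.standard_normal((B, M, K))
weights = rng.standard_normal((L, K, N))  # one weight matrix per LoRA
indices = np.array([2, 0])                # which LoRA each prompt uses

# Before: gather the per-prompt weights, then matmul (take -> matmul).
y_before = np.matmul(x, weights[indices])

# After: matmul against all LoRA weights, then gather (matmul -> take).
all_products = np.einsum("bmk,lkn->blmn", x, weights)
y_after = all_products[np.arange(B), indices]

assert np.allclose(y_before, y_after)
```

   The rewritten form computes `L` times as many matmul outputs per prompt, 
which is the compute increase mentioned below; the payoff comes only once the 
batch-independent `weights` operand lets the matmul be merged with the base 
matmul.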
   
   While this individual pass increases the total amount of compute, the 
overall pipeline should (depending on model size, number of LoRAs, and batch 
size) improve performance by going from three matmuls (one large matmul with 
the base weights and two small matmuls with the LoRA components) to two 
matmuls (one large matmul with `concat(base_weights, lora_a)` and one small 
matmul with `lora_b`).
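   A hedged sketch of that three-matmul-to-two-matmul change (again with 
illustrative shapes and names; the real pipeline operates on Relax IR, not 
NumPy): the base matmul and the `lora_a` matmul share the same activation 
input, so concatenating their weights along the output axis fuses them into 
one large matmul, leaving only the small `lora_b` matmul.

```python
import numpy as np

M, K, N, r = 4, 8, 6, 2  # r is the LoRA rank, much smaller than N in practice
rng = np.random.default_rng(0)
x = rng.standard_normal((M, K))
base_weights = rng.standard_normal((K, N))
lora_a = rng.standard_normal((K, r))
lora_b = rng.standard_normal((r, N))

# Three matmuls: one large (base) and two small (LoRA components).
y_three = x @ base_weights + (x @ lora_a) @ lora_b

# Two matmuls: one large matmul with concat(base_weights, lora_a), then
# split the result and apply the one remaining small matmul with lora_b.
fused = x @ np.concatenate([base_weights, lora_a], axis=1)
y_two = fused[:, :N] + fused[:, N:] @ lora_b

assert np.allclose(y_three, y_two)
```

   The concatenation is what `CombineParallelMatmul` would perform once both 
matmuls take `x` as their left operand.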


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
