Lunderberg commented on PR #16315: URL: https://github.com/apache/tvm/pull/16315#issuecomment-1887437915
This pass would be part of an optimization pipeline that could be used for batched LoRA models. In this usage, the `R.take(weights, indices)` selects which LoRA should be used for each prompt in the batch. By rearranging the matmul, the resulting `R.matmul(x, weights)` can be combined with `R.matmul(x, base_weights)` using `CombineParallelMatmul`. While this individual pass does increase the total amount of compute, the overall pipeline should (depending on model size, number of LoRAs, and batch size) improve performance by going from three matmuls (one large matmul with the base weights and two small matmuls with the LoRA components) to two matmuls (one large matmul with `concat(base_weights, lora_a)` and one small matmul with `lora_b`).
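The rearrangement can be illustrated with a NumPy sketch (all shapes and names here are illustrative, not the pass's actual implementation): selecting a per-prompt LoRA weight before the matmul gives the same result as running the matmul against every LoRA and selecting afterwards, which is what leaves the transformed matmul independent of `indices` and therefore eligible for combination with the base-weight matmul.

```python
import numpy as np

# Hypothetical sizes: 3 LoRAs, a batch of 2 prompts.
num_loras, batch, seq, d_in, d_out = 3, 2, 4, 8, 5
rng = np.random.default_rng(0)
weights = rng.standard_normal((num_loras, d_in, d_out))  # all LoRA weights
x = rng.standard_normal((batch, seq, d_in))              # activations
indices = np.array([2, 0])                               # LoRA chosen per prompt

# Before the rewrite: take, then matmul (one small matmul per prompt).
before = np.matmul(x, weights[indices])

# After the rewrite: matmul against every LoRA, then take.
# This does more compute, but the matmul no longer depends on `indices`,
# so it can be merged with the base-weight matmul.
all_outputs = np.matmul(x[None, :, :, :], weights[:, None, :, :])
after = all_outputs[indices, np.arange(batch)]

assert np.allclose(before, after)
```

The equivalence holds because `take` only selects along the leading weight axis, which the matmul never contracts over.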
