masahi opened a new pull request #7195: URL: https://github.com/apache/tvm/pull/7195
The current implementation of thrust argsort is very inefficient when given multi-dimensional inputs to sort along the innermost axis: it makes `n_iter` calls to thrust sort. See https://github.com/apache/tvm/blob/bad149ed8a555444d813537608ee5cea9e95e97e/src/runtime/contrib/thrust/thrust.cu#L50-L65

When the outer dimension is large, the performance of thrust argsort is far from optimal. In particular, the thrust numbers shown in https://github.com/apache/tvm/pull/7099 do not reflect the true performance thrust can achieve.

This PR replaces the `n_iter` calls to thrust sort with one segmented sort by key. Since thrust doesn't provide an API for segmented sort, I used a neat back-to-back stable-sort-by-key trick explained in https://groups.google.com/forum/#!topic/thrust-users/BoLsxO6b4FY. My implementation is a bit more complicated because we need to do segmented sort **by key**, not just segmented sort.
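For readers unfamiliar with the back-to-back trick, here is a minimal CPU-side sketch of the idea. It is not the code in this PR: `std::stable_sort` stands in for thrust's stable sort by key, and the names `segmented_argsort`, `stable_sort_permutation`, `gather`, and `SegmentedArgsortResult` are hypothetical, chosen only for illustration. It assumes a row-major input of shape `(n_iter, n_values)` sorted in ascending order along the inner axis.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Returns the permutation that stably sorts `keys` in ascending order.
template <typename K>
std::vector<std::size_t> stable_sort_permutation(const std::vector<K>& keys) {
  std::vector<std::size_t> perm(keys.size());
  std::iota(perm.begin(), perm.end(), std::size_t{0});
  std::stable_sort(perm.begin(), perm.end(),
                   [&keys](std::size_t a, std::size_t b) { return keys[a] < keys[b]; });
  return perm;
}

// Reorders `src` so that out[i] = src[perm[i]].
template <typename T>
std::vector<T> gather(const std::vector<T>& src, const std::vector<std::size_t>& perm) {
  std::vector<T> out(src.size());
  for (std::size_t i = 0; i < src.size(); ++i) out[i] = src[perm[i]];
  return out;
}

struct SegmentedArgsortResult {
  std::vector<float> values;  // each row sorted in ascending order
  std::vector<int> indices;   // per-row argsort indices
};

// Hypothetical helper: sorts every row of a row-major (n_iter, n_values)
// array with two global stable sorts instead of n_iter separate sorts.
SegmentedArgsortResult segmented_argsort(const std::vector<float>& data,
                                         std::size_t n_iter, std::size_t n_values) {
  assert(data.size() == n_iter * n_values);
  const std::size_t n = data.size();

  std::vector<float> values = data;
  std::vector<int> indices(n);      // position of each element inside its row
  std::vector<int> segment_ids(n);  // row that each element belongs to
  for (std::size_t i = 0; i < n; ++i) {
    indices[i] = static_cast<int>(i % n_values);
    segment_ids[i] = static_cast<int>(i / n_values);
  }

  // Pass 1: stable sort by value across all rows, carrying the index and
  // the segment id of every element along.
  std::vector<std::size_t> perm = stable_sort_permutation(values);
  values = gather(values, perm);
  indices = gather(indices, perm);
  segment_ids = gather(segment_ids, perm);

  // Pass 2: stable sort by segment id. Stability preserves the value order
  // from pass 1 inside each segment, so every row ends up sorted.
  perm = stable_sort_permutation(segment_ids);
  values = gather(values, perm);
  indices = gather(indices, perm);

  return {std::move(values), std::move(indices)};
}
```

Calling `segmented_argsort` on the 2x3 input `{3, 1, 2, 0, 5, 4}` yields the values `{1, 2, 3, 0, 4, 5}` and the indices `{1, 2, 0, 0, 2, 1}`. The second pass is where the "by key" part adds complexity: both the values and their indices have to be carried through the sort on segment ids.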