masahi commented on pull request #7441: URL: https://github.com/apache/tvm/pull/7441#issuecomment-784615965
@ymwangg @codeislife99 I found a neat trick PyTorch uses for `count`: https://github.com/pytorch/pytorch/blob/22a34bcf4e5eaa348f0117c414c3dd760ec64b13/aten/src/ATen/native/cuda/Unique.cu#L60-L68

Basically, after you get the exclusive scan, instead of copying the original inputs, you copy from an index array [0, 1, 2, ...]. This gives you the start position of each run of equal elements, something like [0, 2, 5], and taking the adjacent difference on it directly gives the counts (the last count is the input length minus the last start position). Does this make sense? It should be much faster than using atomics.

PyTorch uses a separate `unique_by_key` call to compute counts, but since we already have the exclusive scan outputs lying around, we don't need this separate call. So we can be faster than PyTorch.

----------------------------------------------------------------
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
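To make the counting trick in the comment above concrete, here is a minimal sequential sketch in plain Python. It is not the TVM or PyTorch implementation: the function name `unique_with_counts` and the variable names are hypothetical, the input is assumed to be already sorted, and the loops stand in for what would be parallel kernels on a GPU.

```python
from itertools import accumulate


def unique_with_counts(sorted_vals):
    """Return (uniques, starts, counts) for an already-sorted list.

    Hypothetical helper for illustration only; each loop below
    corresponds to a data-parallel step in a GPU version.
    """
    n = len(sorted_vals)
    if n == 0:
        return [], [], []

    # Head flags: 1 where a new run of equal values begins.
    heads = [1] + [int(sorted_vals[i] != sorted_vals[i - 1]) for i in range(1, n)]

    # Exclusive scan of the head flags. accumulate(..., initial=0) yields
    # n + 1 values: the first n are the exclusive scan, the last is the total.
    scanned = list(accumulate(heads, initial=0))
    slots, num_unique = scanned[:n], scanned[n]

    uniques = [None] * num_unique
    starts = [0] * num_unique
    for i in range(n):
        if heads[i]:
            uniques[slots[i]] = sorted_vals[i]  # usual scatter of the input value
            starts[slots[i]] = i                # the trick: scatter the index i instead

    # Adjacent difference of the start positions, with n appended as the
    # end of the last run, gives the count of each unique value.
    bounds = starts + [n]
    counts = [bounds[k + 1] - bounds[k] for k in range(num_unique)]
    return uniques, starts, counts
```

For the sorted input `[1, 1, 3, 3, 3, 7]`, the scattered indices are `[0, 2, 5]` and the adjacent differences (with the input length 6 appended) are `[2, 3, 1]`. No atomic updates are needed because every output slot is written by exactly one element.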
