masahi edited a comment on pull request #7441: URL: https://github.com/apache/tvm/pull/7441#issuecomment-784615965
@ymwangg @codeislife99 I found a neat trick PyTorch uses for `count`: https://github.com/pytorch/pytorch/blob/22a34bcf4e5eaa348f0117c414c3dd760ec64b13/aten/src/ATen/native/cuda/Unique.cu#L60-L68

Basically, after you compute the exclusive scan, instead of copying from the original input, you copy from an index array [0, 1, 2, ...]. This gives you the start offset of each unique element, something like [0, 2, 5], and taking the adjacent difference of that directly gives the counts. Does this make sense? It should be much faster than using atomics.
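A minimal serial sketch of the idea, assuming the input is already sorted (as it is after the sort step of a typical `unique`). The function name `unique_with_counts` is hypothetical, and this is not TVM's or PyTorch's actual code. Each loop corresponds to a data-parallel map, scan, or scatter on the GPU, so no atomics are needed. One detail: n start offsets only yield n - 1 differences, so the total input length is appended as a sentinel to get the count of the last unique element.

```python
def unique_with_counts(data):
    """Return (uniques, counts) for a sorted list without any atomic adds."""
    n = len(data)
    if n == 0:
        return [], []
    # Map: flags[i] = 1 if data[i] starts a new run of equal values.
    flags = [1 if i == 0 or data[i] != data[i - 1] else 0 for i in range(n)]
    # Exclusive scan of the flags: output slot for each run start.
    slots = []
    running = 0
    for f in flags:
        slots.append(running)
        running += f
    num_unique = running
    uniques = [None] * num_unique
    # One extra entry for the sentinel (total length).
    starts = [0] * (num_unique + 1)
    # Scatter: values come from the input, but offsets come from the
    # index array [0, 1, 2, ...] (here simply i) instead of the input.
    for i in range(n):
        if flags[i]:
            uniques[slots[i]] = data[i]
            starts[slots[i]] = i
    starts[num_unique] = n  # sentinel so the last run also gets a count
    # Adjacent difference of the start offsets gives the counts.
    counts = [starts[k + 1] - starts[k] for k in range(num_unique)]
    return uniques, counts
```

For example, `unique_with_counts([1, 1, 3, 3, 3, 7])` scatters the start offsets [0, 2, 5], appends the sentinel 6, and returns `([1, 3, 7], [2, 3, 1])`.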
