masahi commented on pull request #7441:
URL: https://github.com/apache/tvm/pull/7441#issuecomment-784615965


   @ymwangg @codeislife99 I found a neat trick PyTorch uses for `count`. 
https://github.com/pytorch/pytorch/blob/22a34bcf4e5eaa348f0117c414c3dd760ec64b13/aten/src/ATen/native/cuda/Unique.cu#L60-L68
   
   Basically, after you get the exclusive scan, instead of copying the original 
inputs, you copy from an index array [0, 1, 2, ...]. This gives you something 
like [0, 2, 5] (the start position of each unique value), and taking adjacent 
differences on it directly gives the counts. Does this make sense? It should be 
much faster than using atomics.
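
   To make the idea concrete, here is a minimal sequential sketch in plain 
Python. It is not the PyTorch or TVM kernel. The function name 
`unique_with_counts` and all variable names are made up for illustration, it 
assumes the input is already sorted, and it assumes the last count is closed 
off with the input length.

```python
from itertools import accumulate


def unique_with_counts(sorted_vals):
    """Hypothetical helper: unique values, their start indices, and counts."""
    n = len(sorted_vals)
    if n == 0:
        return [], [], []

    # flags[i] == 1 where a new run of equal values starts.
    flags = [1] + [int(sorted_vals[i] != sorted_vals[i - 1]) for i in range(1, n)]

    # Exclusive scan of the flags: the output slot for each run start.
    # accumulate(..., initial=0) yields n + 1 values; the last one is the total.
    scan = list(accumulate(flags, initial=0))
    slots, num_unique = scan[:-1], scan[-1]

    uniques = [None] * num_unique
    first_idx = [0] * num_unique
    for i in range(n):
        if flags[i]:
            uniques[slots[i]] = sorted_vals[i]
            # Copy from the index array [0, 1, 2, ...] instead of the input.
            first_idx[slots[i]] = i

    # Adjacent differences of the start indices give the counts.
    # Assumption: n is appended as the end boundary of the last run.
    bounds = first_idx + [n]
    counts = [bounds[k + 1] - bounds[k] for k in range(num_unique)]
    return uniques, first_idx, counts


uniques, first_idx, counts = unique_with_counts([1, 1, 3, 3, 3, 7])
print(first_idx)  # [0, 2, 5]
print(counts)     # [2, 3, 1]
```

   On a GPU, the scatter loop and the adjacent difference would each be a 
data-parallel pass with no atomics, which is where the expected speedup comes 
from.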
   
   PyTorch uses a separate `unique_by_key` call to compute counts, but since we 
already have the exclusive scan outputs lying around, we don't need this 
separate call. So we can be faster than PyTorch.



