ptrendx opened a new pull request #19905: URL: https://github.com/apache/incubator-mxnet/pull/19905
## Description ##

This PR moves the GPU softmax implementation (not yet the masked softmax implementation) to use RTC and adds multiple optimizations to improve its performance.

## Checklist ##

### Essentials ###

- [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc.)
- [x] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage
- [x] Code is well-documented

### Changes ###

- [x] Moved both the stride1 and non-stride1 versions of the softmax kernels to RTC.
- [x] Improved the performance of the non-stride1 version by running multiple rows per block and coalescing memory accesses. Benchmarks show a ~4x improvement in time for the typical case, and much more (up to ~40x) when the row over which the summation happens is very small.
- [x] Improved the performance of the stride1 kernel by having the entire block collectively load multiple rows into shared memory and by increasing the amount of work per thread (including the ability for a single thread to sum an entire row, down from the previous minimum of 1 full warp per row).
- [x] Eliminated the vectorization requirements of the previous implementation, which yields an especially large speedup when the row length is odd.
- [x] The stride1 kernel can now be used when the output type does not match the input type (e.g. float16 input, float32 output).
- [x] Overall, the performance of the stride1 kernel improved by 1.1x for BERT-like shapes (12 * 32, 128, 128), ~2x for typical sizes with even row length, ~4x for typical sizes with odd row length, and >20x for sizes with very small row length.
- [x] The performance improvements quoted above are for the forward pass; the backward pass shows similar (albeit slightly smaller) improvements.
- [x] Improved the mixed_type utility for RTC kernels: one can now write `type_util::mixed_type<DType, DType2>` instead of the previous, more verbose `typename type_util::mixed_type<DType, DType2>::type`, and an arbitrary number of types can be passed as template arguments.

----------------------------------------------------------------
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at: [email protected]
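For readers unfamiliar with the computation the kernels perform, here is a minimal host-side sketch of a numerically stable softmax over one row (subtract the row max before exponentiating, then normalize). Templating on separate input and output types mirrors the mixed-type support described in the checklist (e.g. float16 input with float32 output on the GPU). This is an illustrative reference only, not the PR's actual kernel code, and the function name `softmax_row` is hypothetical:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical scalar reference for the softmax the GPU kernels compute.
// IType is the input element type, OType the (possibly wider) output type.
template <typename IType, typename OType>
std::vector<OType> softmax_row(const std::vector<IType>& row) {
  // Subtract the row maximum for numerical stability.
  OType row_max = static_cast<OType>(*std::max_element(row.begin(), row.end()));
  std::vector<OType> out(row.size());
  OType sum = OType(0);
  for (size_t i = 0; i < row.size(); ++i) {
    out[i] = std::exp(static_cast<OType>(row[i]) - row_max);
    sum += out[i];
  }
  // Normalize; a plain scalar loop like this has no vectorization
  // requirement, so odd row lengths need no special casing.
  for (OType& v : out) v /= sum;
  return out;
}
```

The GPU versions distribute this work differently (multiple rows per block, coalesced loads, collective shared-memory staging), but the arithmetic per row is the same.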
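The simplified `mixed_type` usage described in the last checklist item can be achieved with an alias template over a recursive implementation struct. The sketch below is a hypothetical illustration of that pattern, not MXNet's actual implementation: in particular, it uses `std::common_type` as a stand-in for whatever promotion rules the real utility applies (the real one handles RTC types such as float16 specially):

```cpp
#include <type_traits>

namespace type_util {
namespace detail {
// Recursive case: fold the parameter pack pairwise through a promotion rule.
template <typename... Ts>
struct mixed_type_impl;

template <typename T>
struct mixed_type_impl<T> { using type = T; };

template <typename T, typename U, typename... Rest>
struct mixed_type_impl<T, U, Rest...> {
  // std::common_type stands in for the library's own promotion logic here.
  using type = typename mixed_type_impl<
      typename std::common_type<T, U>::type, Rest...>::type;
};
}  // namespace detail

// Alias template: callers write mixed_type<A, B, ...> directly, with no
// `typename ...::type` boilerplate, and any number of type arguments.
template <typename... Ts>
using mixed_type = typename detail::mixed_type_impl<Ts...>::type;
}  // namespace type_util
```

With this shape, `type_util::mixed_type<DType, DType2>` is usable directly in a dependent context, which is the ergonomic win the checklist item refers to.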
