kpuatamazon opened a new pull request #19601: URL: https://github.com/apache/incubator-mxnet/pull/19601
## Description ##

Adds a CPU kernel for LayerNorm that handles the common case of `axis = -1`. This is based on the implementation from Marian at https://github.com/marian-nmt/marian-dev/blob/3b468e462809fe42a01a717c8d9307c465e6c35e/src/tensors/cpu/tensor_operators.cpp#L1047-L1087 .

Compared to the MXNet-internal generic implementation, the kernel is 1.6-29x faster. When used in Sockeye, end-to-end translation is 14% faster. Compared to the MKL implementation, the kernel is 0.9-2.28x faster; Marian's kernel is faster than MKL for all tested channel sizes wider than 32.

## Checklist ##

### Essentials ###

- [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
- [x] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage. The existing `test_operator.py:test_layer_norm` covers this well and it passes.
- [x] Code is well-documented (more documented than the baseline)

### Changes ###

- [x] Copy Marian's optimized CPU LayerNorm implementation and adapt it to MXNet.
- [x] Refactor dispatch of optimized versions using a bool return value.

## Benchmarks ##

### Speed ###

- Shapes borrowed from #14935
- c5.12xlarge
- Based on db080058fdd428865b30077eb883a5987208d8b0 (v1.x)
- Ubuntu 18
- `cmake -DCMAKE_BUILD_TYPE=Release -DUSE_MKLDNN=ON -DUSE_CUDA=OFF -DUSE_TVM_OP=OFF -DUSE_MKL_IF_AVAILABLE=OFF -DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8 -GNinja`, except `-DUSE_MKL_IF_AVAILABLE=ON` for the MKL case
- MKL 20190005 when used
- Time in seconds
- `export OMP_NUM_THREADS=4`

Benchmark program:

```python3
#!/usr/bin/env python3
import mxnet as mx
import time

def time_procedure(shape, count):
    data = mx.nd.random_uniform(shape=shape, low=-1.0, high=1.0)
    factors = mx.nd.random_uniform(shape=(shape[-1],))
    mx.nd.waitall()
    begin = time.time()
    for i in range(0, count):
        out = mx.nd.LayerNorm(data, factors, factors)
    mx.nd.waitall()
    return (time.time() - begin) / count

count = 200
for channel in [32, 64, 128, 256, 512, 768, 1024]:
    for batch in [1, 128, 2560, 4096, 8192, 16384]:
        s = (batch, channel)
        timing = time_procedure(s, count)
        print("{:5d}x{:5d} | {:.7f}".format(s[0], s[1], timing))
```

Here are the results (in seconds). Yes, I included the first run. Make your JIT faster.

| Shape | Marian | MKL | MXNet Generic | Marian speedup v MKL | Marian speedup v MXNet |
|--|--|--|--|--|--|
|     1x   32 | 0.0000254 | 0.0000267 | 0.0000409 | 1.05x | 1.61x |
|   128x   32 | 0.0000318 | 0.0000308 | 0.0000632 | 0.97x | 1.99x |
|  2560x   32 | 0.0000690 | 0.0000679 | 0.0004944 | 0.98x | 7.17x |
|  4096x   32 | 0.0000952 | 0.0000907 | 0.0007636 | 0.95x | 8.02x |
|  8192x   32 | 0.0001591 | 0.0001503 | 0.0015753 | 0.94x | 9.90x |
| 16384x   32 | 0.0002900 | 0.0002633 | 0.0030074 | 0.91x | 10.37x |
|     1x   64 | 0.0000240 | 0.0000249 | 0.0000399 | 1.04x | 1.66x |
|   128x   64 | 0.0000311 | 0.0000327 | 0.0000837 | 1.05x | 2.69x |
|  2560x   64 | 0.0000826 | 0.0000984 | 0.0009193 | 1.19x | 11.13x |
|  4096x   64 | 0.0001142 | 0.0001366 | 0.0015389 | 1.20x | 13.48x |
|  8192x   64 | 0.0001985 | 0.0002446 | 0.0029263 | 1.23x | 14.74x |
| 16384x   64 | 0.0003815 | 0.0004561 | 0.0056857 | 1.20x | 14.90x |
|     1x  128 | 0.0000243 | 0.0000254 | 0.0000401 | 1.05x | 1.65x |
|   128x  128 | 0.0000342 | 0.0000397 | 0.0001280 | 1.16x | 3.74x |
|  2560x  128 | 0.0001063 | 0.0001594 | 0.0018591 | 1.50x | 17.49x |
|  4096x  128 | 0.0001501 | 0.0002355 | 0.0028828 | 1.57x | 19.21x |
|  8192x  128 | 0.0002695 | 0.0004378 | 0.0055950 | 1.62x | 20.76x |
| 16384x  128 | 0.0005846 | 0.0008852 | 0.0110546 | 1.51x | 18.91x |
|     1x  256 | 0.0000252 | 0.0000272 | 0.0000424 | 1.08x | 1.68x |
|   128x  256 | 0.0000381 | 0.0000446 | 0.0002133 | 1.17x | 5.60x |
|  2560x  256 | 0.0001542 | 0.0002870 | 0.0035257 | 1.86x | 22.86x |
|  4096x  256 | 0.0002241 | 0.0004369 | 0.0055310 | 1.95x | 24.68x |
|  8192x  256 | 0.0005067 | 0.0008487 | 0.0109084 | 1.67x | 21.53x |
| 16384x  256 | 0.0011817 | 0.0017543 | 0.0217319 | 1.48x | 18.39x |
|     1x  512 | 0.0000262 | 0.0000306 | 0.0000475 | 1.17x | 1.81x |
|   128x  512 | 0.0000405 | 0.0000549 | 0.0003818 | 1.36x | 9.43x |
|  2560x  512 | 0.0002462 | 0.0005229 | 0.0068302 | 2.12x | 27.74x |
|  4096x  512 | 0.0003823 | 0.0008172 | 0.0108432 | 2.14x | 28.36x |
|  8192x  512 | 0.0008764 | 0.0017205 | 0.0216015 | 1.96x | 24.65x |
| 16384x  512 | 0.0057181 | 0.0072662 | 0.0464290 | 1.27x | 8.12x |
|     1x  768 | 0.0000274 | 0.0000309 | 0.0000519 | 1.13x | 1.89x |
|   128x  768 | 0.0000439 | 0.0000675 | 0.0005498 | 1.54x | 12.52x |
|  2560x  768 | 0.0003469 | 0.0007757 | 0.0101437 | 2.24x | 29.24x |
|  4096x  768 | 0.0005857 | 0.0013381 | 0.0161946 | 2.28x | 27.65x |
|  8192x  768 | 0.0014930 | 0.0026524 | 0.0322792 | 1.78x | 21.62x |
| 16384x  768 | 0.0088047 | 0.0110582 | 0.0698267 | 1.26x | 7.93x |
|     1x 1024 | 0.0000275 | 0.0000330 | 0.0000573 | 1.20x | 2.08x |
|   128x 1024 | 0.0000486 | 0.0000790 | 0.0007189 | 1.63x | 14.79x |
|  2560x 1024 | 0.0004582 | 0.0010214 | 0.0135037 | 2.23x | 29.47x |
|  4096x 1024 | 0.0008070 | 0.0017359 | 0.0215496 | 2.15x | 26.70x |
|  8192x 1024 | 0.0057007 | 0.0073134 | 0.0463280 | 1.28x | 8.13x |
| 16384x 1024 | 0.0116098 | 0.0147560 | 0.0935520 | 1.27x | 8.06x |

### AWS Sockeye ###

Observed a 14% speed-up in end-to-end machine translation with Sockeye: Sockeye 2.2 (29795b82) on a c5.12xlarge with `export OMP_NUM_THREADS=4`, translating a test set. Compiled on Ubuntu 18 with `cmake -DCMAKE_BUILD_TYPE=Release -DUSE_MKLDNN=ON -DUSE_CUDA=OFF -DUSE_TVM_OP=OFF -DUSE_MKL_IF_AVAILABLE=OFF -DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8 -GNinja ..` Note: no MKL.

Before
```
[INFO:__main__] Processed 2964 lines.
Total time: 133.3097, sec/sent: 0.0450, sent/sec: 22.2339

real    2m15.716s
user    9m52.988s
sys     0m13.504s
```

After
```
[INFO:__main__] Processed 2964 lines.
Total time: 116.6679, sec/sent: 0.0394, sent/sec: 25.4054

real    1m58.858s
user    8m45.803s
sys     0m13.823s
```

The above runs were done as normal, without the profiler. I then turned the profiler on. We can see that LayerNorm was consuming a substantial amount of time:

Before
```
operator
=================
Name                                       Total Count  Time (ms)   Min Time (ms)  Max Time (ms)  Avg Time (ms)
----                                       -----------  ---------   -------------  -------------  -------------
_contrib_intgemm_fully_connected           822520       26357.8887  0.0090         0.3390         0.0320
LayerNorm                                  459522       20225.8086  0.0230         0.4860         0.0440
elemwise_add                               601340       7813.2148   0.0040         0.1970         0.0130
_contrib_interleaved_matmul_encdec_qk      155884       7557.1152   0.0050         0.3560         0.0485
_contrib_interleaved_matmul_encdec_valatt  155884       6168.3472   0.0040         0.4120         0.0396
FullyConnected                             48262        4070.1250   0.0260         4.7480         0.0843
DeleteVariable                             1577462      3830.7241   0.0000         0.3660         0.0024
Concat                                     107622       3493.2380   0.0100         0.2970         0.0325
take                                       386096       3484.5449   0.0020         1.5600         0.0090
SliceChannel                               65296        3468.1431   0.0060         0.4370         0.0531
where                                      144786       3203.5801   0.0030         0.2090         0.0221
Activation                                 252408       3095.2820   0.0060         0.1750         0.0123
```

After
```
operator
=================
Name                                       Total Count  Time (ms)   Min Time (ms)  Max Time (ms)  Avg Time (ms)
----                                       -----------  ---------   -------------  -------------  -------------
_contrib_intgemm_fully_connected           822316       25351.8438  0.0090         0.4190         0.0308
elemwise_add                               601170       8229.7861   0.0040         0.1650         0.0137
_contrib_interleaved_matmul_encdec_qk      155850       7577.9399   0.0050         0.4030         0.0486
_contrib_interleaved_matmul_encdec_valatt  155850       6169.1318   0.0040         0.4310         0.0396
FullyConnected                             48245        4170.0972   0.0240         4.8480         0.0864
DeleteVariable                             1576986      3935.9939   0.0000         0.3490         0.0025
take                                       385960       3624.0161   0.0020         2.4180         0.0094
Concat                                     107605       3561.9041   0.0100         0.3540         0.0331
SliceChannel                               65296        3475.8010   0.0060         0.5690         0.0532
where                                      144735       3241.1169   0.0030         0.2080         0.0224
Activation                                 252340       2855.7710   0.0050         0.2440         0.0113
LayerNorm                                  459403       2791.0029   0.0040         0.0540         0.0061
```

The new implementation is 7.21x as fast on average according to the profiler. The number of LayerNorm invocations changes by 0.02% because beam-search iterations are affected by tie breaking.

### Unit test ###

Before: 62.210s
After: 61.321s

Note that the unit tests spend most of their time comparing results rather than running the kernels.

## Comments ##

- LayerNorm is one of those kernels whose output changes slightly with any implementation, so outputs that depend on near-ties will change.
- float16 support on CPU accumulates in float32. Since float16 exists only in conversions on CPU, this is faster anyway. Also, there was no OMP reduction for float16.
- There is no threaded parallelization within a channel (i.e. for the sum). I doubt existing channel sizes justify this given the cost of threading on CPUs.
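For reference, the computation being optimized here (LayerNorm over the last axis) can be sketched in NumPy. This is an illustration of the math only, not the PR's C++ kernel; `eps` is assumed to default to 1e-5 as in MXNet's `LayerNorm` operator:

```python
import numpy as np

def layer_norm_last_axis(x, gamma, beta, eps=1e-5):
    # Mean and (biased) variance over the last axis, one row at a time.
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    # Normalize, then apply the learned scale (gamma) and shift (beta).
    return (x - mean) / np.sqrt(var + eps) * gamma + beta
```

Each output row then has approximately zero mean and unit variance (up to `eps`) before the scale and shift are applied, which is why an optimized kernel only needs two reductions per row plus one elementwise pass.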

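The float16 comment above can be illustrated with a small NumPy experiment (hypothetical, not the kernel's actual code): a long running sum kept in float16 stalls once it grows beyond float16's precision, while the same sum accumulated in float32 stays accurate.

```python
import numpy as np

# Many small values whose true sum is large relative to float16 precision.
vals = np.full(16384, 0.1, dtype=np.float16)

# Accumulate in float16: once the running sum is large, adding 0.1 rounds
# away and the accumulator stops growing.
acc16 = np.float16(0.0)
for v in vals:
    acc16 = np.float16(acc16 + v)

# Accumulate in float32: upcast first, sum stays close to the true value.
acc32 = vals.astype(np.float32).sum()
```

Here `acc16` ends up far below the true sum (roughly 1638), while `acc32` is close to it, which is why accumulating float16 LayerNorm statistics in float32 costs nothing in accuracy and, per the comment above, is faster anyway.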