kpuatamazon opened a new pull request #19601:
URL: https://github.com/apache/incubator-mxnet/pull/19601


   ## Description ##
   Adds a CPU kernel for LayerNorm that handles the common case of `axis = -1`.  This is based upon the implementation from Marian at https://github.com/marian-nmt/marian-dev/blob/3b468e462809fe42a01a717c8d9307c465e6c35e/src/tensors/cpu/tensor_operators.cpp#L1047-L1087.
   
   Compared to the MXNet-internal generic implementation, the kernel is 1.6-29x faster.  When used in Sockeye, end-to-end translation is 14% faster.
   Compared to the MKL implementation, the kernel is 0.9-2.28x as fast; Marian's version is faster than MKL for all tested channel sizes wider than 32.
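
   For reference, the computation the kernel performs is standard LayerNorm over the last axis.  A minimal NumPy sketch (illustrative only; the actual kernel is vectorized C++ adapted from Marian):
   ```python
   import numpy as np

   def layer_norm_last_axis(data, gamma, beta, eps=1e-5):
       # Normalize each row over the last axis, then scale and shift
       # with the per-channel gamma and beta vectors.
       mean = data.mean(axis=-1, keepdims=True)
       var = data.var(axis=-1, keepdims=True)
       return gamma * (data - mean) / np.sqrt(var + eps) + beta
   ```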
   
   ## Checklist ##
   ### Essentials ###
   - [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
   - [x] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage.  There's already a `test_operator.py:test_layer_norm` that covers this well and it passes.
   - [x] Code is well-documented (more documented than the baseline)
   
   ### Changes ###
   - [x] Copy Marian optimized CPU LayerNorm implementation and adapt to MXNet.
   - [x] Refactor dispatch of optimized versions using a bool return value (sketched below).
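
   The dispatch pattern: each optimized implementation reports through its return value whether it handled the inputs, and the caller falls back to the generic kernel otherwise.  A minimal Python sketch (the real code is C++; all names here are hypothetical):
   ```python
   def layer_norm_generic(data, gamma, beta, axis, eps, out):
       ...  # stands in for MXNet's generic implementation

   def layer_norm_optimized(data, gamma, beta, axis, eps, out):
       # Only handles reduction over the last axis; report failure otherwise.
       if axis not in (-1, data.ndim - 1):
           return False
       # ... vectorized kernel writes its result into `out` ...
       return True

   def layer_norm(data, gamma, beta, axis, eps, out):
       # Try the optimized kernel first; fall back to the generic one.
       if not layer_norm_optimized(data, gamma, beta, axis, eps, out):
           layer_norm_generic(data, gamma, beta, axis, eps, out)
   ```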
   
   ## Benchmarks ##
   ### Speed ###
   
   - Shapes borrowed from #14935
   - c5.12xlarge
   - Based on db080058fdd428865b30077eb883a5987208d8b0 (v1.x)
   - Ubuntu 18
   - `cmake -DCMAKE_BUILD_TYPE=Release -DUSE_MKLDNN=ON -DUSE_CUDA=OFF -DUSE_TVM_OP=OFF -DUSE_MKL_IF_AVAILABLE=OFF -DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8 -GNinja`, except for the MKL case, where `-DUSE_MKL_IF_AVAILABLE=ON`
   - MKL 2019.0.5 (version macro 20190005) when used.
   - Time in seconds.
   - `export OMP_NUM_THREADS=4`
   
   Benchmark program
   ```python
   #!/usr/bin/env python3
   import mxnet as mx
   import time

   def time_procedure(shape, count):
     # Random input plus a per-channel vector reused for both gamma and beta.
     data = mx.nd.random_uniform(shape=shape, low=-1.0, high=1.0)
     factors = mx.nd.random_uniform(shape=(shape[-1],))
     mx.nd.waitall()
     # Time `count` synchronous LayerNorm calls and report the average.
     begin = time.time()
     for i in range(count):
       out = mx.nd.LayerNorm(data, factors, factors)
       mx.nd.waitall()
     return (time.time() - begin) / count
   
   count = 200
   
   for channel in [32, 64, 128, 256, 512, 768, 1024]:
     for batch in [1, 128, 2560, 4096, 8192, 16384]:
       s = (batch, channel)
       timing = time_procedure(s, count)
       print("{:5d}x{:5d} | {:.7f}".format(s[0], s[1], timing))
   ```
   
   Here are the results (in seconds).  Yes, I included the first run.  Make your JIT faster.
   
   | Shape | Marian | MKL | MXNet Generic | Marian speedup v MKL | Marian speedup v MXNet |
   |    1x   32 | 0.0000254| 0.0000267| 0.0000409|1.05x |  1.61x |
   |  128x   32 | 0.0000318| 0.0000308| 0.0000632|0.97x |  1.99x |
   | 2560x   32 | 0.0000690| 0.0000679| 0.0004944|0.98x |  7.17x |
   | 4096x   32 | 0.0000952| 0.0000907| 0.0007636|0.95x |  8.02x |
   | 8192x   32 | 0.0001591| 0.0001503| 0.0015753|0.94x |  9.90x |
   |16384x   32 | 0.0002900| 0.0002633| 0.0030074|0.91x | 10.37x |
   |    1x   64 | 0.0000240| 0.0000249| 0.0000399|1.04x |  1.66x |
   |  128x   64 | 0.0000311| 0.0000327| 0.0000837|1.05x |  2.69x |
   | 2560x   64 | 0.0000826| 0.0000984| 0.0009193|1.19x | 11.13x |
   | 4096x   64 | 0.0001142| 0.0001366| 0.0015389|1.20x | 13.48x |
   | 8192x   64 | 0.0001985| 0.0002446| 0.0029263|1.23x | 14.74x |
   |16384x   64 | 0.0003815| 0.0004561| 0.0056857|1.20x | 14.90x |
   |    1x  128 | 0.0000243| 0.0000254| 0.0000401|1.05x |  1.65x |
   |  128x  128 | 0.0000342| 0.0000397| 0.0001280|1.16x |  3.74x |
   | 2560x  128 | 0.0001063| 0.0001594| 0.0018591|1.50x | 17.49x |
   | 4096x  128 | 0.0001501| 0.0002355| 0.0028828|1.57x | 19.21x |
   | 8192x  128 | 0.0002695| 0.0004378| 0.0055950|1.62x | 20.76x |
   |16384x  128 | 0.0005846| 0.0008852| 0.0110546|1.51x | 18.91x |
   |    1x  256 | 0.0000252| 0.0000272| 0.0000424|1.08x |  1.68x |
   |  128x  256 | 0.0000381| 0.0000446| 0.0002133|1.17x |  5.60x |
   | 2560x  256 | 0.0001542| 0.0002870| 0.0035257|1.86x | 22.86x |
   | 4096x  256 | 0.0002241| 0.0004369| 0.0055310|1.95x | 24.68x |
   | 8192x  256 | 0.0005067| 0.0008487| 0.0109084|1.67x | 21.53x |
   |16384x  256 | 0.0011817| 0.0017543| 0.0217319|1.48x | 18.39x |
   |    1x  512 | 0.0000262| 0.0000306| 0.0000475|1.17x |  1.81x |
   |  128x  512 | 0.0000405| 0.0000549| 0.0003818|1.36x |  9.43x |
   | 2560x  512 | 0.0002462| 0.0005229| 0.0068302|2.12x | 27.74x |
   | 4096x  512 | 0.0003823| 0.0008172| 0.0108432|2.14x | 28.36x |
   | 8192x  512 | 0.0008764| 0.0017205| 0.0216015|1.96x | 24.65x |
   |16384x  512 | 0.0057181| 0.0072662| 0.0464290|1.27x |  8.12x |
   |    1x  768 | 0.0000274| 0.0000309| 0.0000519|1.13x |  1.89x |
   |  128x  768 | 0.0000439| 0.0000675| 0.0005498|1.54x | 12.52x |
   | 2560x  768 | 0.0003469| 0.0007757| 0.0101437|2.24x | 29.24x |
   | 4096x  768 | 0.0005857| 0.0013381| 0.0161946|2.28x | 27.65x |
   | 8192x  768 | 0.0014930| 0.0026524| 0.0322792|1.78x | 21.62x |
   |16384x  768 | 0.0088047| 0.0110582| 0.0698267|1.26x |  7.93x |
   |    1x 1024 | 0.0000275| 0.0000330| 0.0000573|1.20x |  2.08x |
   |  128x 1024 | 0.0000486| 0.0000790| 0.0007189|1.63x | 14.79x |
   | 2560x 1024 | 0.0004582| 0.0010214| 0.0135037|2.23x | 29.47x |
   | 4096x 1024 | 0.0008070| 0.0017359| 0.0215496|2.15x | 26.70x |
   | 8192x 1024 | 0.0057007| 0.0073134| 0.0463280|1.28x |  8.13x |
   |16384x 1024 | 0.0116098| 0.0147560| 0.0935520|1.27x |  8.06x |
   
   
   ### AWS Sockeye ###
   Observed a 14% speedup in end-to-end machine translation with Sockeye: Sockeye 2.2 (29795b82) on a c5.12xlarge with `export OMP_NUM_THREADS=4`, translating a test set.
   
   Compiled on Ubuntu 18 with `cmake -DCMAKE_BUILD_TYPE=Release -DUSE_MKLDNN=ON -DUSE_CUDA=OFF -DUSE_TVM_OP=OFF -DUSE_MKL_IF_AVAILABLE=OFF -DCMAKE_C_COMPILER=gcc-8 -DCMAKE_CXX_COMPILER=g++-8 -GNinja ..`.  Note: no MKL.
   
   Before
   ```
   [INFO:__main__] Processed 2964 lines. Total time: 133.3097, sec/sent: 0.0450, sent/sec: 22.2339
   
   real 2m15.716s
   user 9m52.988s
   sys  0m13.504s
   ```
   
   After
   ```
   [INFO:__main__] Processed 2964 lines. Total time: 116.6679, sec/sent: 0.0394, sent/sec: 25.4054
   
   real 1m58.858s
   user 8m45.803s
   sys  0m13.823s
   ```
   
   The above runs were done as normal, without the profiler.  I then turned the profiler on (a sketch of one way to collect such a summary follows the tables).  We can see that LayerNorm is consuming a substantial amount of time:
   Before
   ```
   operator
   =================
   Name                                          Total Count        Time (ms)    Min Time (ms)    Max Time (ms)    Avg Time (ms)
   ----                                          -----------        ---------    -------------    -------------    -------------
   _contrib_intgemm_fully_connected                   822520       26357.8887           0.0090           0.3390           0.0320
   LayerNorm                                          459522       20225.8086           0.0230           0.4860           0.0440
   elemwise_add                                       601340        7813.2148           0.0040           0.1970           0.0130
   _contrib_interleaved_matmul_encdec_qk              155884        7557.1152           0.0050           0.3560           0.0485
   _contrib_interleaved_matmul_encdec_valatt          155884        6168.3472           0.0040           0.4120           0.0396
   FullyConnected                                      48262        4070.1250           0.0260           4.7480           0.0843
   DeleteVariable                                    1577462        3830.7241           0.0000           0.3660           0.0024
   Concat                                             107622        3493.2380           0.0100           0.2970           0.0325
   take                                               386096        3484.5449           0.0020           1.5600           0.0090
   SliceChannel                                        65296        3468.1431           0.0060           0.4370           0.0531
   where                                              144786        3203.5801           0.0030           0.2090           0.0221
   Activation                                         252408        3095.2820           0.0060           0.1750           0.0123
   
   ```
   After
   ```
   operator
   =================
   Name                                          Total Count        Time (ms)    Min Time (ms)    Max Time (ms)    Avg Time (ms)
   ----                                          -----------        ---------    -------------    -------------    -------------
   _contrib_intgemm_fully_connected                   822316       25351.8438           0.0090           0.4190           0.0308
   elemwise_add                                       601170        8229.7861           0.0040           0.1650           0.0137
   _contrib_interleaved_matmul_encdec_qk              155850        7577.9399           0.0050           0.4030           0.0486
   _contrib_interleaved_matmul_encdec_valatt          155850        6169.1318           0.0040           0.4310           0.0396
   FullyConnected                                      48245        4170.0972           0.0240           4.8480           0.0864
   DeleteVariable                                    1576986        3935.9939           0.0000           0.3490           0.0025
   take                                               385960        3624.0161           0.0020           2.4180           0.0094
   Concat                                             107605        3561.9041           0.0100           0.3540           0.0331
   SliceChannel                                        65296        3475.8010           0.0060           0.5690           0.0532
   where                                              144735        3241.1169           0.0030           0.2080           0.0224
   Activation                                         252340        2855.7710           0.0050           0.2440           0.0113
   LayerNorm                                          459403        2791.0029           0.0040           0.0540           0.0061
   ```
   The new implementation is 7.21x as fast on average according to the profiler (0.0440 ms vs. 0.0061 ms average per call).
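
   For reference, a minimal sketch of one way to collect such an operator summary with the MXNet 1.x profiler (the exact invocation used for these runs is not shown here):
   ```python
   import mxnet as mx

   # Aggregate per-operator statistics (count, total/min/max/avg time).
   mx.profiler.set_config(profile_all=True, aggregate_stats=True,
                          filename='profile.json')
   mx.profiler.set_state('run')

   # ... run the workload to be measured ...
   data = mx.nd.random_uniform(shape=(128, 512))
   gamma = mx.nd.ones((512,))
   out = mx.nd.LayerNorm(data, gamma, gamma)
   mx.nd.waitall()

   mx.profiler.set_state('stop')
   print(mx.profiler.dumps())  # prints a table like the ones above
   ```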
   
   The number of LayerNorm invocations changes by 0.02% because beam search iterations are affected by tie breaking.
   
   ### Unit test ###
   Before: 62.210s
   After: 61.321s
   
   But note that the unit tests spend most of their time comparing outputs rather than running the kernels.
   
   ## Comments ##
   - LayerNorm is one of those kernels whose output shifts slightly with any change of implementation (floating-point summation order changes rounding), so outputs that depend on near-ties will change.
   - float16 support on CPU accumulates in float32; see the demo after this list.  Since float16 on CPU only exists via conversion, this is faster anyway.  Also, there was no OMP reduction for float16.
   - There is no threaded parallelization within a channel (i.e. for the sums).  I doubt existing channel sizes justify it given the cost of threading on CPUs.
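
   To illustrate the float32-accumulation point, a small NumPy demo (illustrative only; the kernel itself is C++):
   ```python
   import numpy as np

   x = np.full(4096, 0.1, dtype=np.float16)  # true sum is ~409.5

   # Naive float16 accumulation stalls: once the running sum reaches 256,
   # the float16 spacing (0.25) exceeds twice the addend, so each add
   # rounds back to the old value.
   acc = np.float16(0.0)
   for v in x:
       acc = np.float16(acc + v)
   print(acc)  # 256.0

   # Accumulating in float32 and converting once at the end stays accurate.
   print(np.float16(x.astype(np.float32).sum()))  # 409.5
   ```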

