sxjscience opened a new issue #18043: [Performance][Numpy] np.einsum can be 500-1000 times slower than torch.einsum
URL: https://github.com/apache/incubator-mxnet/issues/18043
 
 
   The performance of `np.einsum` on GPU is poor: it is usually 500-1000 times slower than `th.einsum`. Because `einsum` is essential for implementing the attention mechanism used in NLP and CV models, we should accelerate the implementation.
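
   For context, here is a minimal NumPy sketch of the attention-score pattern that motivates this issue (the shapes are arbitrary, chosen only for illustration):

   ```python
   import numpy as np

   # (batch, heads, seq, channels) layouts, as in multi-head attention
   query = np.random.normal(0, 1, (2, 4, 10, 16))
   key = np.random.normal(0, 1, (2, 4, 12, 16))

   # Attention scores: contract over the channel axis d, keep (b, n, i, j)
   scores = np.einsum('bnid,bnjd->bnij', query, key)
   print(scores.shape)  # (2, 4, 10, 12)
   ```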
   
   Here is the code for profiling the different implementations of einsum (also available as a gist: https://gist.github.com/sxjscience/bfda1a8bd2942d93eef5ddf8a15b52b8). The profiling results show the following ordering, from fastest to slowest:

   **PyTorch einsum > MXNet no-einsum >> MXNet einsum**
   
   ```python
   
   import mxnet as mx
   import numpy as np
   import torch as th
   import argparse
   mx.npx.set_np()
   
   parser = argparse.ArgumentParser(description='Profile einsum')
    parser.add_argument('--mode', choices=['einsum', 'no_einsum', 'th_einsum'],
                        required=True)
    parser.add_argument('--problem', type=int, choices=[0, 1, 2],
                        help='Problem type.', required=True)
   args = parser.parse_args()
   
   np.random.seed(100)
   batch_size = 64
   num_heads = 8
   seq_length_A = 100
   seq_length_B = 50
   units = 128
   
    if args.problem == 0:
        # (B, K, T0, C) x (B, K, T1, C) -> (B, K, T0, T1)
        lhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_A, units))
        rhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_B, units))
       mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
       mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
       mx.npx.waitall()
       th_lhs = th.from_numpy(lhs).float().cuda()
       th_rhs = th.from_numpy(rhs).float().cuda()
       typ = 'bnid,bnjd->bnij'
       if args.mode == 'einsum':
           out = mx.np.einsum(typ, mx_lhs, mx_rhs)
           out_np = out.asnumpy()
       elif args.mode == 'no_einsum':
           out = mx.npx.batch_dot(mx_lhs, mx_rhs, transpose_b=True)
           out_np = out.asnumpy()
       elif args.mode == 'th_einsum':
           out = th.einsum(typ, th_lhs, th_rhs)
           out_np = out.cpu().numpy()
       else:
           raise NotImplementedError
       print(out_np.shape)
    elif args.problem == 1:
        # (B, T0, K, C) x (B, T1, K, C) -> (B, K, T0, T1)
        lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
        rhs = np.random.normal(0, 1, (batch_size, seq_length_B, num_heads, units))
       mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
       mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
       mx.npx.waitall()
       th_lhs = th.from_numpy(lhs).float().cuda()
       th_rhs = th.from_numpy(rhs).float().cuda()
       typ = 'bind,bjnd->bnij'
       if args.mode == 'einsum':
           out = mx.np.einsum(typ, mx_lhs, mx_rhs)
           out_np = out.asnumpy()
       elif args.mode == 'no_einsum':
           out = mx.npx.batch_dot(mx.np.swapaxes(mx_lhs, 1, 2),
                                  mx.np.swapaxes(mx_rhs, 1, 2),
                                  transpose_b=True)
           out_np = out.asnumpy()
       elif args.mode == 'th_einsum':
           out = th.einsum(typ, th_lhs, th_rhs)
           out_np = out.cpu().numpy()
       else:
           raise NotImplementedError
       print(out_np.shape)
    elif args.problem == 2:
        # (B, T0, K, C) x (T1, K, C) -> (B, K, T0, T1)
        lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
        rhs = np.random.normal(0, 1, (seq_length_B, num_heads, units))
       mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
       mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
       mx.npx.waitall()
       th_lhs = th.from_numpy(lhs).float().cuda()
       th_rhs = th.from_numpy(rhs).float().cuda()
       typ = 'bind,jnd->bnij'
       if args.mode == 'einsum':
           out = mx.np.einsum(typ, mx_lhs, mx_rhs)
           out_np = out.asnumpy()
       elif args.mode == 'no_einsum':
           out = mx.np.matmul(mx.np.swapaxes(mx_lhs, 1, 2),
                              mx.np.transpose(mx_rhs, (1, 2, 0)))
           out_np = out.asnumpy()
       elif args.mode == 'th_einsum':
           out = th.einsum(typ, th_lhs, th_rhs)
           out_np = out.cpu().numpy()
       else:
           raise NotImplementedError
       print(out_np.shape)
   
   ```
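
   For a quick wall-clock comparison without nvprof, one could also time each mode with explicit synchronization, since both frameworks launch GPU kernels asynchronously. This is a minimal sketch assuming a GPU is available and reusing the arrays from the script above:

   ```python
   import time

   def time_op(fn, sync, n_warmup=5, n_runs=50):
       """Average wall-clock time of fn(); sync() drains the async GPU queue."""
       for _ in range(n_warmup):
           fn()
       sync()
       start = time.time()
       for _ in range(n_runs):
           fn()
       sync()
       return (time.time() - start) / n_runs

   # e.g. with the problem-0 arrays and typ defined above:
   # time_op(lambda: mx.np.einsum(typ, mx_lhs, mx_rhs), mx.npx.waitall)
   # time_op(lambda: th.einsum(typ, th_lhs, th_rhs), th.cuda.synchronize)
   ```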
   
   We profiled three different usages of einsum (a NumPy sketch verifying the equivalences used by the no-einsum baselines follows the list):
   
   1. (B, K, T0, C) X (B, K, T1, C) --> (B, K, T0, T1)
      - MXNet einsum
         `nvprof python profile_einsum.py --mode einsum --problem 0`
   
         | Time | Kernel |
         | ----- | -------|
          | 41.009ms | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0\_ |
      - MXNet implementation without einsum
         `nvprof python profile_einsum.py --mode no_einsum --problem 0`
   
         | Time | Kernel |
         | ----- | -------|
         | 198.75us | volta_sgemm_128x64_tn |
      - PyTorch implementation
         `nvprof python profile_einsum.py --mode th_einsum --problem 0`
   
         | Time | Kernel |
         | ----- | -------|
         | 192.35us | volta_sgemm_128x64_tn |
   
   2. (B, T0, K, C) X (B, T1, K, C) --> (B, K, T0, T1)
      - MXNet einsum
         `nvprof python profile_einsum.py --mode einsum --problem 1`
   
         | Time | Kernel |
         | ----- | -------|
          | 40.665ms | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0\_ |
      - MXNet implementation without einsum
         `nvprof python profile_einsum.py --mode no_einsum --problem 1`
   
         | Time | Kernel |
         | ----- | -------|
         | 185.76us | volta_sgemm_128x64_tn |
          | 89.519us | void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5) |
      - PyTorch implementation
         `nvprof python profile_einsum.py --mode th_einsum --problem 1`
   
         | Time | Kernel |
         | ----- | -------|
         | 193.02us | volta_sgemm_128x64_tn |
          | 61.967us | \_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1\_ |
   
   
   3. (B, K, T0, C) X (T1, K, C) --> (B, K, T0, T1)
      - MXNet einsum
         `nvprof python profile_einsum.py --mode einsum --problem 2`
   
         | Time | Kernel |
         | ----- | -------|
          | 40.551ms | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0\_ |
      - MXNet implementation without einsum
         `nvprof python profile_einsum.py --mode no_einsum --problem 2`
   
         | Time | Kernel |
         | ----- | -------|
          | 322.33us | \_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_16broadcast_kernelINS0_10mshadow_op8identityEEEJPfS7_N7mshadow5ShapeILi5EEESA_NS_9OpReqTypeEmEEEviDpT0\_ |
         | 183.23us | volta_sgemm_128x64_nn |
          | 120.13us | void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5) |
          | 5.3120us | void mxnet::op::cuda::transpose_pseudo2D<float, unsigned long, bool=0>(float*, float, int, int, int, int) |
      - PyTorch implementation
         `nvprof python profile_einsum.py --mode th_einsum --problem 2`
   
         | Time | Kernel |
         | ----- | -------|
         | 152.16us | volta_sgemm_128x64_tn |
          | 28.704us | \_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1\_ |
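
   As referenced above, here is a small CPU NumPy sketch verifying that each einsum spec is equivalent to the corresponding no-einsum formulation, so the no-einsum timings are a fair baseline:

   ```python
   import numpy as np

   b, n, t0, t1, c = 2, 3, 5, 4, 6

   # Problem 0: bnid,bnjd->bnij == batched matmul with rhs transposed
   lhs0, rhs0 = np.random.rand(b, n, t0, c), np.random.rand(b, n, t1, c)
   assert np.allclose(np.einsum('bnid,bnjd->bnij', lhs0, rhs0),
                      np.matmul(lhs0, rhs0.swapaxes(-1, -2)))

   # Problem 1: bind,bjnd->bnij == move heads before sequence, then batched matmul
   lhs1, rhs1 = np.random.rand(b, t0, n, c), np.random.rand(b, t1, n, c)
   assert np.allclose(np.einsum('bind,bjnd->bnij', lhs1, rhs1),
                      np.matmul(lhs1.swapaxes(1, 2),
                                rhs1.swapaxes(1, 2).swapaxes(-1, -2)))

   # Problem 2: bind,jnd->bnij == swap lhs axes, transpose rhs to (n, d, j)
   lhs2, rhs2 = np.random.rand(b, t0, n, c), np.random.rand(t1, n, c)
   assert np.allclose(np.einsum('bind,jnd->bnij', lhs2, rhs2),
                      np.matmul(lhs2.swapaxes(1, 2), rhs2.transpose(1, 2, 0)))
   print('all three specs match their no-einsum equivalents')
   ```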
   
   @yzhliu @hzfan @haojin2 @reminisce @szha 
