sxjscience opened a new issue #18043: [Performance][Numpy] np.einsum can be 500 - 1000 times slower than torch.einsum
URL: https://github.com/apache/incubator-mxnet/issues/18043

The performance of `np.einsum` on GPU is poor: it is usually 500 - 1000 times slower than `th.einsum`. Because `einsum` is essential for implementing the attention mechanisms used in NLP and CV, we should accelerate the implementation. Here is the code to profile the different implementations of einsum (also in a gist: https://gist.github.com/sxjscience/bfda1a8bd2942d93eef5ddf8a15b52b8). The profiling results show the following performance order: **PyTorch einsum > MXNet no-einsum >> MXNet einsum**.

```python
import argparse

import mxnet as mx
import numpy as np
import torch as th

mx.npx.set_np()

parser = argparse.ArgumentParser(description='Profile einsum')
parser.add_argument('--mode', choices=['einsum', 'no_einsum', 'th_einsum'],
                    default='einsum', required=True)
parser.add_argument('--problem', type=int, choices=[0, 1, 2],
                    help='Problem type.', default=0, required=True)
args = parser.parse_args()

np.random.seed(100)
batch_size = 64
num_heads = 8
seq_length_A = 100
seq_length_B = 50
units = 128

if args.problem == 0:
    # Problem 0: (B, K, T0, C) x (B, K, T1, C) -> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_A, units))
    rhs = np.random.normal(0, 1, (batch_size, num_heads, seq_length_B, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bnid,bnjd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.npx.batch_dot(mx_lhs, mx_rhs, transpose_b=True)
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
elif args.problem == 1:
    # Problem 1: (B, T0, K, C) x (B, T1, K, C) -> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
    rhs = np.random.normal(0, 1, (batch_size, seq_length_B, num_heads, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bind,bjnd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.npx.batch_dot(mx.np.swapaxes(mx_lhs, 1, 2),
                               mx.np.swapaxes(mx_rhs, 1, 2),
                               transpose_b=True)
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
elif args.problem == 2:
    # Problem 2: (B, T0, K, C) x (T1, K, C) -> (B, K, T0, T1)
    lhs = np.random.normal(0, 1, (batch_size, seq_length_A, num_heads, units))
    rhs = np.random.normal(0, 1, (seq_length_B, num_heads, units))
    mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
    mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
    mx.npx.waitall()
    th_lhs = th.from_numpy(lhs).float().cuda()
    th_rhs = th.from_numpy(rhs).float().cuda()
    typ = 'bind,jnd->bnij'
    if args.mode == 'einsum':
        out = mx.np.einsum(typ, mx_lhs, mx_rhs)
        out_np = out.asnumpy()
    elif args.mode == 'no_einsum':
        out = mx.np.matmul(mx.np.swapaxes(mx_lhs, 1, 2),
                           mx.np.transpose(mx_rhs, (1, 2, 0)))
        out_np = out.asnumpy()
    elif args.mode == 'th_einsum':
        out = th.einsum(typ, th_lhs, th_rhs)
        out_np = out.cpu().numpy()
    else:
        raise NotImplementedError
    print(out_np.shape)
```
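The `no_einsum` branches are intended as numerically equivalent rewrites of the einsum expressions. As a sanity check, a minimal sketch along these lines (not part of the gist; it uses problem 1's shapes, and the tolerances are arbitrary float32 choices) can confirm that the transpose + `batch_dot` path computes the same result:

```python
import mxnet as mx
import numpy as np

mx.npx.set_np()

# Problem 1 shapes: (B, T0, K, C) and (B, T1, K, C)
lhs = mx.np.array(np.random.normal(0, 1, (64, 100, 8, 128)),
                  dtype=np.float32, ctx=mx.gpu())
rhs = mx.np.array(np.random.normal(0, 1, (64, 50, 8, 128)),
                  dtype=np.float32, ctx=mx.gpu())

# einsum path: (B, T0, K, C) x (B, T1, K, C) -> (B, K, T0, T1)
ref = mx.np.einsum('bind,bjnd->bnij', lhs, rhs)

# no-einsum path: move the head axis ahead of the sequence axis,
# then contract the last axis with a batched GEMM
alt = mx.npx.batch_dot(mx.np.swapaxes(lhs, 1, 2),
                       mx.np.swapaxes(rhs, 1, 2),
                       transpose_b=True)

# Loose float32 tolerances: the two paths may accumulate in different orders
np.testing.assert_allclose(ref.asnumpy(), alt.asnumpy(), rtol=1e-4, atol=1e-4)
print('einsum and batch_dot rewrite match, output shape:', ref.shape)
```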
We profiled three different usages of einsum:

1. (B, K, T0, C) X (B, K, T1, C) --> (B, K, T0, T1)

   - MXNet einsum: `nvprof python profile_einsum.py --mode einsum --problem 0`

     | Time | Kernel |
     | ----- | ------- |
     | 41.009ms | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_` |

   - MXNet implementation without einsum: `nvprof python profile_einsum.py --mode no_einsum --problem 0`

     | Time | Kernel |
     | ----- | ------- |
     | 198.75us | `volta_sgemm_128x64_tn` |

   - PyTorch implementation: `nvprof python profile_einsum.py --mode th_einsum --problem 0`

     | Time | Kernel |
     | ----- | ------- |
     | 192.35us | `volta_sgemm_128x64_tn` |

2. (B, T0, K, C) X (B, T1, K, C) --> (B, K, T0, T1)

   - MXNet einsum: `nvprof python profile_einsum.py --mode einsum --problem 1`

     | Time | Kernel |
     | ----- | ------- |
     | 40.665ms | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_` |

   - MXNet implementation without einsum: `nvprof python profile_einsum.py --mode no_einsum --problem 1`

     | Time | Kernel |
     | ----- | ------- |
     | 185.76us | `volta_sgemm_128x64_tn` |
     | 89.519us | `void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5)` |

   - PyTorch implementation: `nvprof python profile_einsum.py --mode th_einsum --problem 1`

     | Time | Kernel |
     | ----- | ------- |
     | 193.02us | `volta_sgemm_128x64_tn` |
     | 61.967us | `_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1_` |

3. (B, T0, K, C) X (T1, K, C) --> (B, K, T0, T1)

   - MXNet einsum: `nvprof python profile_einsum.py --mode einsum --problem 2`

     | Time | Kernel |
     | ----- | ------- |
     | 40.551ms | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12numpy_einsumILi5ELi1ELb0EdEEJPfNS_6common11StaticArrayIS5_Li16EEEN7mshadow5ShapeILi5EEENS7_ISB_Li16EEESB_SC_iiS5_EEEviDpT0_` |

   - MXNet implementation without einsum: `nvprof python profile_einsum.py --mode no_einsum --problem 2`

     | Time | Kernel |
     | ----- | ------- |
     | 322.33us | `_ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_16broadcast_kernelINS0_10mshadow_op8identityEEEJPfS7_N7mshadow5ShapeILi5EEESA_NS_9OpReqTypeEmEEEviDpT0_` |
     | 183.23us | `volta_sgemm_128x64_nn` |
     | 120.13us | `void mshadow::cuda::MapPlanKernel<mshadow::sv::saveto, int=8, mshadow::expr::Plan<mshadow::Tensor<mshadow::gpu, int=5, float>, float>, mshadow::expr::Plan<mshadow::expr::SwapAxisExp<mshadow::Tensor<mshadow::gpu, int=5, float>, float, int=5, int=2, int=1>, float>>(mshadow::gpu, int, mshadow::Shape<int=2>, int=5)` |
     | 5.3120us | `void mxnet::op::cuda::transpose_pseudo2D<float, unsigned long, bool=0>(float*, float, int, int, int, int)` |

   - PyTorch implementation: `nvprof python profile_einsum.py --mode th_einsum --problem 2`

     | Time | Kernel |
     | ----- | ------- |
     | 152.16us | `volta_sgemm_128x64_tn` |
     | 28.704us | `_ZN2at6native18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_21copy_device_to_deviceERNS_14TensorIteratorEbENKUlvE_clEvENKUlvE2_clEvEUlfE_EEvS4_RKT_EUliE2_EEviT1_` |

@yzhliu @hzfan @haojin2 @reminisce @szha
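For anyone who wants to reproduce the gap without nvprof, a rough wall-clock comparison shows the same order of magnitude. This is a minimal sketch (not part of the gist above) that times problem 0's einsum in both frameworks; `n_repeats` is an arbitrary choice:

```python
import time

import mxnet as mx
import numpy as np
import torch as th

mx.npx.set_np()

lhs = np.random.normal(0, 1, (64, 8, 100, 128))
rhs = np.random.normal(0, 1, (64, 8, 50, 128))
mx_lhs = mx.np.array(lhs, dtype=np.float32, ctx=mx.gpu())
mx_rhs = mx.np.array(rhs, dtype=np.float32, ctx=mx.gpu())
th_lhs = th.from_numpy(lhs).float().cuda()
th_rhs = th.from_numpy(rhs).float().cuda()

n_repeats = 10  # arbitrary; enough to amortize kernel-launch overhead

# MXNet: block on all pending async GPU work before starting/stopping the timer
mx.npx.waitall()
start = time.time()
for _ in range(n_repeats):
    mx_out = mx.np.einsum('bnid,bnjd->bnij', mx_lhs, mx_rhs)
mx.npx.waitall()
print('mx.np.einsum: {:.3f} ms/iter'.format(
    (time.time() - start) * 1000 / n_repeats))

# PyTorch: CUDA ops are also asynchronous, so synchronize explicitly
th.cuda.synchronize()
start = time.time()
for _ in range(n_repeats):
    th_out = th.einsum('bnid,bnjd->bnij', th_lhs, th_rhs)
th.cuda.synchronize()
print('torch.einsum: {:.3f} ms/iter'.format(
    (time.time() - start) * 1000 / n_repeats))
```

Synchronizing around the loops matters: both frameworks dispatch GPU kernels asynchronously, so without `mx.npx.waitall()` / `torch.cuda.synchronize()` the timer would only measure launch overhead.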
