stu1130 commented on issue #14725: Performance Regression on CUDA10
URL: https://github.com/apache/incubator-mxnet/issues/14725#issuecomment-485599298

Update on what @anirudh2290 and I have done so far:

1. Based on the MXNet profiler results (I have updated the issue description), we found that `_backward_FullyConnected` increased from 0.0838 to 0.1352 (see the profiler sketch at the end of this comment for how the timings were collected).
2. We added the following logging statement below https://github.com/apache/incubator-mxnet/blob/68efc1598240bbe36b91d6660489519431795a5d/src/operator/nn/fully_connected-inl.h#L168
```
LOG(INFO) << grad.shape_ << " " << data.shape_ << " " << gwmat.shape_;
```
3. I wrote a Python script to collect the grad shape, data shape, and gwmat shape (a sketch of such a script is included at the end of this comment).
```
# (grad shape, data shape, gwmat shape): count for 1 epoch
('640,10000', '640,650', '10000,650'): 465
('32,2600', '32,650', '2600,650'): 138640
('960,10000', '960,650', '10000,650'): 440
('1600,10000', '1600,650', '10000,650'): 48
('1280,10000', '1280,650', '10000,650'): 192
('320,10000', '320,650', '10000,650'): 154
('1920,10000', '1920,650', '10000,650'): 9
```
4. @anirudh2290 wrote a script to reproduce it:
```
import time
import random

import mxnet as mx
import numpy as np

mx.random.seed(1234)
np.random.seed(1234)
random.seed(1234)

a = mx.nd.random.uniform(shape=(32, 650), ctx=mx.gpu(0))
b = mx.nd.random.uniform(shape=(2600, 650), ctx=mx.gpu(0))
c = mx.nd.random.uniform(shape=(2600,), ctx=mx.gpu(0))

data = mx.sym.var("data")
weight = mx.sym.var("weight")
bias = mx.sym.var("bias")
out = mx.sym.FullyConnected(data, weight, bias, flatten=True, no_bias=False, num_hidden=2600)
out = mx.sym.make_loss(out)

ex = out.bind(args={"data": a, "weight": b, "bias": c},
              args_grad={"data": mx.nd.zeros(a.shape, ctx=mx.gpu(0)),
                         "weight": mx.nd.zeros(b.shape, ctx=mx.gpu(0)),
                         "bias": mx.nd.zeros(c.shape, ctx=mx.gpu(0))},
              ctx=mx.gpu(0))
mx.nd.waitall()

total_cost = 0
num = 100
for i in range(num):
    # skip the first few iterations as warmup
    if i < 5:
        continue
    start = time.time()
    ex.forward(is_train=True, data=a, weight=b, bias=c)
    ex.backward()
    ex.outputs[0].asnumpy()
    ex.grad_arrays[0].asnumpy()
    mx.nd.waitall()
    end = time.time() - start
    total_cost += end
print("Average: {}".format(total_cost / num))
```
And we got the following results:
```
('640,10000', '640,650', '10000,650')
CUDA 10:  Average: 0.00515265703201294
CUDA 9.2: Average: 0.005957033634185791

('32,2600', '32,650', '2600,650')
CUDA 10:  Average: 0.0008404040336608886
CUDA 9.2: Average: 0.0008698630332946777

('960,10000', '960,650', '10000,650')
CUDA 10:  Average: 0.02021113634109497
CUDA 9.2: Average: 0.01884469509124756

('1600,10000', '1600,650', '10000,650')
CUDA 10:  Average: 0.030648748874664306
CUDA 9.2: Average: 0.02739755630493164

('1280,10000', '1280,650', '10000,650')
CUDA 10:  Average: 0.023832192420959474
CUDA 9.2: Average: 0.02541511058807373

('320,10000', '320,650', '10000,650')
CUDA 10:  Average: 0.003631892204284668
CUDA 9.2: Average: 0.0030696773529052734

('1920,10000', '1920,650', '10000,650')
CUDA 10:  Average: 0.03472873210906982
CUDA 9.2: Average: 0.03499695301055908
```
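For reference, a minimal sketch of how the operator-level timings in step 1 can be gathered with the built-in MXNet profiler; the output filename and the placeholder for the training loop are assumptions, not part of the original setup.
```
import mxnet as mx

# Enable operator-level profiling; the trace goes to a JSON file viewable in
# chrome://tracing, and aggregate_stats enables a per-operator summary.
mx.profiler.set_config(profile_all=True, aggregate_stats=True,
                       filename='fc_backward_profile.json')  # hypothetical filename
mx.profiler.set_state('run')

# ... run the training / benchmark loop here ...

mx.nd.waitall()                # make sure all asynchronous GPU work has finished
mx.profiler.set_state('stop')
print(mx.profiler.dumps())     # per-operator stats, e.g. _backward_FullyConnected
```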
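And a minimal sketch of the kind of shape-counting script mentioned in step 3. This is not the original script: it assumes the training log was captured to a file and that each LOG(INFO) line ends with three shapes printed as `[640,10000] [640,650] [10000,650]`; the `count_shapes` helper name is hypothetical.
```
import re
import sys
from collections import Counter

# Matches three bracketed shapes on a log line,
# e.g. "[640,10000] [640,650] [10000,650]" (assumed log format).
SHAPE_RE = re.compile(r"\[([\d,]+)\]\s+\[([\d,]+)\]\s+\[([\d,]+)\]")

def count_shapes(log_path):
    """Tally (grad shape, data shape, gwmat shape) triples found in the log."""
    counts = Counter()
    with open(log_path) as f:
        for line in f:
            m = SHAPE_RE.search(line)
            if m:
                counts[m.groups()] += 1
    return counts

if __name__ == "__main__":
    for shapes, n in count_shapes(sys.argv[1]).most_common():
        print("{}: {}".format(shapes, n))
```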
