barry-jin opened a new pull request #20087:
URL: https://github.com/apache/incubator-mxnet/pull/20087


   ## Description ##
   Adopt the PackedFunc-based FFI for some frequently used numpy_extension (`npx`) operators. Benchmark results (time per call, in seconds) for the legacy FFI and the new FFI are as follows:
   npx.softmax and npx.log_softmax
   ```python
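   import timeit
   
   # Timing loop assumed to match the other benchmarks below.
   num_repeat = 1000
   
   # npx.softmax
   setup = """
   from mxnet import np, npx
   npx.set_np()
   a = np.ones((2, 2))
   """
   stmt = """
   npx.softmax(a)
   """
   timer = timeit.Timer(setup=setup, stmt=stmt)
   print(min(timer.repeat(repeat=10, number=num_repeat)) / num_repeat)
   # legacy:  7.195554300000006e-05 s per call
   # New FFI: 5.542412399999996e-05 s per call
   
   # npx.log_softmax
   setup = """
   from mxnet import np, npx
   npx.set_np()
   a = np.array([1, 2, .1])
   """
   stmt = """
   npx.log_softmax(a)
   """
   timer = timeit.Timer(setup=setup, stmt=stmt)
   print(min(timer.repeat(repeat=10, number=num_repeat)) / num_repeat)
   # legacy:  7.109741400000047e-05 s per call
   # New FFI: 5.514604400000067e-05 s per call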
   ```
   
   npx.activation
   ```python
   import timeit
   
   setup = """
   from mxnet import np, npx, img
   from mxnet.ndarray.numpy import _internal as _npi
   from mxnet.ndarray.numpy import _api_internal
   npx.set_np()
   x = np.arange(8).reshape((2, 2, 2))
   """
   stmt = """
   o = npx.activation(x, act_type='sigmoid')
   """
   timer = timeit.Timer(setup=setup,
                        stmt=stmt)
   num_repeat = 1000
   print(min(timer.repeat(repeat=10, number=num_repeat)) / num_repeat)
   
   # legacy:  0.00010351458600000019 s per call
   # New FFI: 4.998078599999989e-05 s per call
   ```
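The activation, batch_norm, and fully_connected benchmarks repeat the same `timeit` boilerplate. A small helper can factor it out; `per_call_seconds` is a hypothetical name (not part of MXNet) that simply wraps the pattern used above: the best of `repeat` runs of `number` calls, divided by `number`.

```python
import timeit


def per_call_seconds(setup: str, stmt: str, number: int = 1000, repeat: int = 10) -> float:
    """Return the best-of-`repeat` time per call of `stmt`, in seconds.

    Hypothetical helper mirroring the benchmark loop in this PR description:
    each of the `repeat` runs executes `stmt` `number` times after `setup`.
    """
    timer = timeit.Timer(stmt=stmt, setup=setup)
    # Take the fastest run to reduce noise from other processes, then
    # divide by the number of calls in that run.
    return min(timer.repeat(repeat=repeat, number=number)) / number


if __name__ == "__main__":
    # Pure-Python stand-in; with MXNet installed, pass the setup/stmt
    # strings from the benchmarks instead.
    print(per_call_seconds(setup="x = list(range(100))", stmt="sum(x)"))
```

With MXNet installed, calling `per_call_seconds(setup, stmt)` on the strings from any benchmark here reproduces the `print(...)` line; absolute values will depend on hardware and build.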
   
   npx.batch_norm
   ```python
   import timeit
   
   setup = """
   from mxnet import np, npx, img
   from mxnet.ndarray.numpy import _internal as _npi
   from mxnet.ndarray.numpy import _api_internal
   npx.set_np()
   a = np.zeros((2,2))
   b = np.ones((2,2))
   c = np.ones((2,2))
   d = np.ones((2,2))
   e = np.ones((2,2))
   """
   stmt = """
   o = npx.batch_norm(a, b, c, d, e,
                      eps=1e-3,
                      momentum=0.9,
                      fix_gamma=True,
                      use_global_stats=False,
                      output_mean_var=False,
                      axis=1,
                      cudnn_off=False,
                      min_calib_range=None,
                      max_calib_range=None)
   """
   
   timer = timeit.Timer(setup=setup,
                        stmt=stmt)
   num_repeat = 1000
   print(min(timer.repeat(repeat=10, number=num_repeat)) / num_repeat)
   
   # legacy:  0.00017782039200000012 s per call
   # New FFI: 8.22070250000002e-05 s per call
   ```
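`npx.batch_norm` is called here with nine keyword arguments. As a rough, hypothetical illustration of why a PackedFunc-style FFI can reduce per-call overhead, the plain-Python toy below contrasts two calling conventions. It assumes (our understanding of the legacy imperative-invoke path, not something stated in this PR) that keyword values are converted to strings on the Python side and parsed back into typed parameters by the backend, while a packed call hands typed values over directly. `SCHEMA`, `legacy_invoke`, and `packed_invoke` are illustrative names, not MXNet API, and the toy makes no claim about where the measured time actually goes.

```python
from typing import Any, Callable, Dict

# Hypothetical parameter schema for a toy operator; MXNet's real parameter
# parsing lives in C++ and is not modeled here.
SCHEMA: Dict[str, Callable[[str], Any]] = {
    "eps": float,
    "momentum": float,
    "fix_gamma": lambda s: s == "True",
    "axis": int,
}


def legacy_invoke(**kwargs: Any) -> Dict[str, Any]:
    """Toy 'legacy' call: serialize every keyword value, then parse it back."""
    keys = list(kwargs)
    # Frontend side: every value becomes a string.
    vals = [str(v) for v in kwargs.values()]
    # Backend side: each string is parsed into a typed value via the schema.
    return {k: SCHEMA[k](v) for k, v in zip(keys, vals)}


def packed_invoke(eps: float, momentum: float, fix_gamma: bool, axis: int) -> Dict[str, Any]:
    """Toy 'packed' call: typed values arrive as-is, with no string round trip."""
    return {"eps": eps, "momentum": momentum, "fix_gamma": fix_gamma, "axis": axis}


if __name__ == "__main__":
    a = legacy_invoke(eps=1e-3, momentum=0.9, fix_gamma=True, axis=1)
    b = packed_invoke(1e-3, 0.9, True, 1)
    print(a == b)  # both routes yield the same typed parameters
```

Both routes end with identical parameters; the difference is only the extra string conversion and parsing per argument on the legacy route, which grows with the number of keyword arguments.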
   
   npx.fully_connected
   ```python
   import timeit
   
   setup = """
   from mxnet import np, npx, img
   from mxnet.ndarray.numpy import _internal as _npi
   from mxnet.ndarray.numpy import _api_internal
   npx.set_np()
   data = np.arange(8).reshape((4,2))
   weight = np.arange(4).reshape((2,2))
   bias = np.arange(2).reshape((2,))
   num_hidden = bias.shape[0]
   """
   stmt = """
   o = npx.fully_connected(data, weight, bias, num_hidden=num_hidden, no_bias=False)
   """
   timer = timeit.Timer(setup=setup,
                        stmt=stmt)
   num_repeat = 1000
   print(min(timer.repeat(repeat=10, number=num_repeat)) / num_repeat)
   
   # legacy:  9.581770600000006e-05 s per call
   # New FFI: 6.12347289999997e-05 s per call
   ```
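To summarize the benchmarks above, the speedup for each operator is the legacy time divided by the new-FFI time. The snippet below only redoes that arithmetic on the reported per-call values; `RESULTS` and `speedups` are illustrative names.

```python
from typing import Dict, Tuple

# Per-call times in seconds, (legacy, new FFI), copied from the results above.
RESULTS: Dict[str, Tuple[float, float]] = {
    "softmax": (7.195554300000006e-05, 5.542412399999996e-05),
    "log_softmax": (7.109741400000047e-05, 5.514604400000067e-05),
    "activation": (0.00010351458600000019, 4.998078599999989e-05),
    "batch_norm": (0.00017782039200000012, 8.22070250000002e-05),
    "fully_connected": (9.581770600000006e-05, 6.12347289999997e-05),
}


def speedups(results: Dict[str, Tuple[float, float]]) -> Dict[str, float]:
    """Map each operator to legacy_time / new_time (values > 1 mean faster)."""
    return {op: legacy / new for op, (legacy, new) in results.items()}


if __name__ == "__main__":
    for op, ratio in speedups(RESULTS).items():
        print(f"{op:16s} {ratio:.2f}x")
```

This works out to roughly 1.3x for `softmax` and `log_softmax`, about 2.1x for `activation`, about 2.2x for `batch_norm`, and about 1.6x for `fully_connected`.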
   ## Checklist ##
   ### Essentials ###
   - [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc.)
   - [x] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage
   - [x] Code is well-documented
   
   ### Changes ###
   - [ ] Feature1, tests, (and when applicable, API doc)
   - [ ] Feature2, tests, (and when applicable, API doc)
   
   ## Comments ##
   - If this change is a backward incompatible change, why must this change be made.
   - Interesting edge cases to note here
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

