RuRo opened a new pull request #17734: [MXNET-889] Implement ONNX export for 
gluon LSTM.
URL: https://github.com/apache/incubator-mxnet/pull/17734
 
 
   ## Description ##
   Implements 3 new node conversion functions to facilitate the conversion of 
`gluon.rnn.LSTM` blocks to ONNX.
   
   1) `_zeros` (as well as `_ones` and `_full` as a bonus)
        `gluon.rnn.LSTM` uses `sym.zeros` to [initialize the recurrent 
states](https://github.com/apache/incubator-mxnet/blob/10a12d59f67ad21032a94b1721aaf9b96fddac85/python/mxnet/gluon/rnn/rnn_layer.py#L234).
   2) `_rnn_param_concat`, which is used by mxnet to [concatenate all the 
flattened 
parameters](https://github.com/apache/incubator-mxnet/blob/10a12d59f67ad21032a94b1721aaf9b96fddac85/python/mxnet/gluon/rnn/rnn_layer.py#L278) 
before passing them to the `RNN` node.
   3) `RNN` which is the backend op for the actual `LSTM` computation.
   
   I've also implemented/updated a couple of helper functions in the 
[`_op_translations.py`](https://github.com/RuRo/incubator-mxnet/blob/master/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py)
 file in a separate commit, to hopefully reduce the amount of boilerplate.
   
   ## Checklist ##
   ### Essentials ###
   Please feel free to remove inapplicable items for your PR.
   - [x] The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to 
the relevant [JIRA issue](https://issues.apache.org/jira/projects/MXNET/issues) 
created (except PRs with tiny changes)
   - [x] Changes are complete (i.e. I finished coding on this PR)
   - [ ] All changes have test coverage:
   - Unit tests are added for small changes to verify correctness (e.g. adding 
a new operator)
   - Nightly tests are added for complicated/long-running ones (e.g. changing 
distributed kvstore)
   - Build tests will be added for build configuration changes (e.g. adding a 
new build option with NCCL)
   - [x] Code is well-documented: 
   - For user-facing API changes, API doc string has been updated. 
   - For new C++ functions in header files, their functionalities and arguments 
are documented. 
   - For new examples, README.md is added to explain what the example does, 
the source of the dataset, expected performance on the test set and a reference 
to the original paper if applicable
   - Check the API doc at 
https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
   - [x] To the best of my knowledge, examples are either not affected by this 
change, or have been fixed to be compatible with this change
   
   ## Tests ##
   
   I haven't added any tests for this node yet, since I am not 100% sure how 
exactly to do that.
   Some other PRs with ONNX exports just add the new node names to a list in 
[`tests/python-pytest/onnx/test_cases.py`](https://github.com/apache/incubator-mxnet/blob/10a12d59f67ad21032a94b1721aaf9b96fddac85/tests/python-pytest/onnx/test_cases.py#L96).
 However, I don't see how this would add any actual tests without a reference 
ONNX implementation available.
   
   I would gladly add some unit tests if someone can point me in the right 
direction, though. For now, I've verified the correctness of my conversion by
   1) Running this script:
   
       <details>
       <summary>test_rnn_export_onnx.py</summary>
   
       ```python
        #!/usr/bin/env python3
       import numpy as np
       import mxnet as mx
       import onnxruntime as rt
   
        import logging
        import os
        logging.getLogger().setLevel(logging.INFO)
   
   
       # Configuration
       path = 'lstm-model'
       seq_len = 10
       batch_size = 32
       input_size = 16
       hidden_size = 8
       bidirectional = True
   
   
       # Create input data
       shape = (seq_len, batch_size, input_size)
       I = mx.nd.random.randn(*shape)
   
   
       # Create LSTM network
       seq = mx.gluon.nn.HybridSequential()
       with seq.name_scope():
           seq.add(mx.gluon.rnn.LSTM(hidden_size, bidirectional=bidirectional))
       seq.initialize()
       seq(I)
   
   
       # Init all parameters with random values
       for k, p in seq.collect_params().items():
           p.data()[...] = np.random.randn(*p.data().shape)
       seq.hybridize(static_alloc=True, static_shape=True)
       res_mx = seq(I)
       res_mx = res_mx.asnumpy()
   
   
       # Export to symbol and then to ONNX and forward with onnxruntime
       seq.export(path)
       mx.contrib.onnx.export_model(
           f'{path}-symbol.json',
           f'{path}-0000.params',
           [shape], np.float32,
           onnx_file_path=f'{path}.onnx',
           verbose=True
       )
   
   
       # Forward, using onnxruntime
       sess = rt.InferenceSession(f'{path}.onnx')
       res_onnx, = sess.run(None, {'data': I.asnumpy()})
   
   
       # Check results
       diff = np.abs(res_mx - res_onnx)
       print(f"{diff.mean()=}")
       print(f"{diff.max()=}")
       np.testing.assert_almost_equal(res_mx, res_onnx, decimal=4)
       # decimal=4 is the same as mxnet unittests for other ONNX ops
   
   
       # Clean up files
       os.remove(f"{path}-symbol.json")
       os.remove(f"{path}-0000.params")
       os.remove(f"{path}.onnx")
       ```
   
       </details>
   
        This script generates a model with an `LSTM` node, fills all the 
parameters with random data, exports it to `symbol.json` and `0000.params` 
files, and then converts the model to `ONNX`. It then loads the model using 
`onnxruntime`, forwards some random data through both the `mxnet` and the 
`onnx` models, and asserts that the results are close.
   
   2) Using this conversion on a CRNN network I trained at my job and 
verifying that the prediction results are the same on a subset of our test set. 
(Unfortunately, I can't give too many details about this model, so you'll have 
to trust me here.)
   
   ## Comments ##
   There are a few details in my current implementation that should be noted:
   1) `_zeros`
        - It seems that `mx.sym` can broadcast a `_zeros/_ones/_full` node 
along an axis with dimension 0 as if it were a 1, while `ONNX` can't do that.
           ```python
           import mxnet as mx
           a = mx.sym.zeros((10, 0, 30))
           b = mx.sym.Variable('b')
           (a * b).eval(b=mx.nd.zeros((10, 20, 30)))[0].shape
           ```
           Currently, my implementation converts all the 0 dims to 1s. Is this 
correct?
        - The version of `ONNX` that mxnet uses doesn't support some newer 
nodes like `ConstantOfShape`, and I was not able to find a working onnx 
implementation that actually supports the currently deprecated 
`ConstantFill` op, so I had to store the `_zeros` data in a tensor, 
which ends up being stored in the onnx model file.<br/>
        This also applies to the helper functions and the other nodes. I've had 
to settle for some "less than elegant" solutions in some cases because of the 
lacking opset support. I might also have missed a better way to do some 
operations, so if you have any better alternatives, I would gladly change my 
implementation.<br/>
   
   2) `_rnn_param_concat`
        You can read more about my implementation in the code (comments), but 
the general outline is that since `ONNX` expects the `RNN` parameters prepared 
in a completely different way compared to mxnet, I've had to employ some "dark" 
magic, involving `convert_rnn_param_concat` generating multiple nodes and 
`convert_RNN` 'guessing' these node names using regular expressions.<br/>
        As a result, the current conversion mechanism only works if the 
parameters of the RNN block are concatenated using `_rnn_param_concat` (and not, 
for example, `Concat`). If the conversion of `mx.sym.RNN` (without `gluon`) 
is a widely requested feature, we would have to convert the `_rnn_param_concat` 
node to a regular `Concat` and then split the parameters back apart in the 
`RNN` conversion, which is not very nice IMO.<br/>
   
   3) `RNN`
        Even though this PR implements the conversion of the `RNN` node, 
currently only the `LSTM` mode is supported, and many of the "customization" 
options are also not implemented.
       **Implemented:**
       - `mode=LSTM`
       - `layout='TNC'`
       - `bidirectional=True` and `False`
       - `state_clip_min` and `state_clip_max` (but only if they are the same 
value)
       
       **Not Implemented**, but probably supported by ONNX (easy to implement):
       - `mode=GRU, RNN_relu, RNN_tanh`
       - `use_sequence_length`
       - custom `states` inputs and `states` outputs
       
       **Not Implemented**, and probably **not** supported by ONNX (hard to 
implement):
       - `state_clip_nan`
       - `state_clip_min != state_clip_max`
       - `dropout`
       - `projection_size`
       - `num_layers`
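
   To make the 0-dim broadcasting question from point 1 concrete, here is a 
minimal numpy sketch (the shapes are made up; numpy follows the same 
broadcasting rules as ONNX here, so it stands in for the ONNX behaviour):

```python
import numpy as np

# mxnet's sym API broadcasts a 0-sized axis in _zeros as if it were a 1;
# numpy (like ONNX) refuses to broadcast a 0 dim against 20:
b = np.zeros((10, 20, 30))
try:
    np.zeros((10, 0, 30)) * b
except ValueError as e:
    print("0 dim does not broadcast:", e)

# Replacing the 0 dim with 1 (what the conversion currently does)
# broadcasts normally:
a = np.zeros((10, 1, 30))
print((a * b).shape)  # (10, 20, 30)
```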
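
   The parameter layout mismatch behind point 2 can be sketched roughly in 
plain numpy (a single unidirectional layer; gate reordering and multi-layer 
layouts are omitted, and the variable names are made up, not taken from the 
actual implementation):

```python
import numpy as np

input_size, hidden, ng = 16, 8, 4  # ng = number of LSTM gates

# MXNet's RNN op consumes a single flat parameter vector, roughly
# [W_i2h, W_h2h, b_i2h, b_h2h], each flattened and concatenated:
w_i2h = np.random.randn(ng * hidden, input_size)
w_h2h = np.random.randn(ng * hidden, hidden)
b_i2h = np.random.randn(ng * hidden)
b_h2h = np.random.randn(ng * hidden)
flat = np.concatenate([w_i2h.ravel(), w_h2h.ravel(), b_i2h, b_h2h])

# ONNX's LSTM instead takes three separate inputs W, R and B, each with
# a leading num_directions axis (and its own gate order, omitted here):
W = w_i2h.reshape(1, ng * hidden, input_size)
R = w_h2h.reshape(1, ng * hidden, hidden)
B = np.concatenate([b_i2h, b_h2h]).reshape(1, 2 * ng * hidden)
print(flat.shape, W.shape, R.shape, B.shape)
```

   Recovering `W`, `R` and `B` from `flat` alone requires knowing all the 
sizes, which is why the conversion intercepts the inputs of 
`_rnn_param_concat` instead of splitting the flat vector inside `convert_RNN`.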
