RuRo opened a new pull request #17734: [MXNET-889] Implement ONNX export for gluon LSTM.
URL: https://github.com/apache/incubator-mxnet/pull/17734

## Description ##

Implements 3 new node conversion functions to facilitate the conversion of `gluon.rnn.LSTM` blocks to ONNX:

1) `_zeros` (as well as `_ones` and `_full` as a bonus). `gluon.rnn.LSTM` uses `sym.zeros` to [initialize the initial recurrent states](https://github.com/apache/incubator-mxnet/blob/10a12d59f67ad21032a94b1721aaf9b96fddac85/python/mxnet/gluon/rnn/rnn_layer.py#L234).
2) `_rnn_param_concat`, which is used by mxnet to [concatenate all the flattened parameters](https://github.com/apache/incubator-mxnet/blob/10a12d59f67ad21032a94b1721aaf9b96fddac85/python/mxnet/gluon/rnn/rnn_layer.py#L278) before passing them to the `RNN` node.
3) `RNN`, which is the backend op for the actual `LSTM` computation.

I've also implemented/updated a couple of helper functions in the [`_op_translations.py`](https://github.com/RuRo/incubator-mxnet/blob/master/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py) file in a separate commit, to hopefully reduce the amount of boilerplate.

## Checklist ##

### Essentials ###

Please feel free to remove inapplicable items for your PR.

- [x] The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant [JIRA issue](https://issues.apache.org/jira/projects/MXNET/issues) created (except PRs with tiny changes)
- [x] Changes are complete (i.e. I finished coding on this PR)
- [ ] All changes have test coverage:
  - Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  - Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  - Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
- [x] Code is well-documented:
  - For user-facing API changes, the API doc string has been updated.
  - For new C++ functions in header files, their functionalities and arguments are documented.
  - For new examples, a README.md is added to explain what the example does, the source of the dataset, expected performance on the test set and a reference to the original paper if applicable
  - Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
- [x] To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

## Tests ##

I didn't add any tests for this node yet, since I am not 100% sure how exactly to do that. Some other PRs with ONNX exports just add the new node names to a list in [`tests/python-pytest/onnx/test_cases.py`](https://github.com/apache/incubator-mxnet/blob/10a12d59f67ad21032a94b1721aaf9b96fddac85/tests/python-pytest/onnx/test_cases.py#L96); however, I don't see how this would add any actual tests without a reference ONNX implementation available. I would gladly add some unit tests if someone can point me in the right direction, though (a rough sketch of one possible test is included below).
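For what it's worth, here is a minimal sketch of how the verification script from the next section could be wrapped into a pytest-style round-trip test. The test name, the `tmp_path` usage and the parametrization are my own assumptions, not an existing convention in the mxnet test suite:

```python
import numpy as np
import mxnet as mx
import onnxruntime as rt
import pytest


@pytest.mark.parametrize('bidirectional', [False, True])
def test_lstm_onnx_export(tmp_path, bidirectional):
    # Build and initialize a small LSTM, then record the mxnet output
    shape = (10, 32, 16)  # (seq_len, batch_size, input_size)
    x = mx.nd.random.randn(*shape)
    net = mx.gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(mx.gluon.rnn.LSTM(8, bidirectional=bidirectional))
    net.initialize()
    net.hybridize()
    expected = net(x).asnumpy()

    # Round-trip: export to symbol/params files, convert to ONNX,
    # then forward the same input through onnxruntime
    prefix = str(tmp_path / 'lstm')
    net.export(prefix)
    mx.contrib.onnx.export_model(
        prefix + '-symbol.json', prefix + '-0000.params',
        [shape], np.float32, onnx_file_path=prefix + '.onnx')
    sess = rt.InferenceSession(prefix + '.onnx')
    actual, = sess.run(None, {'data': x.asnumpy()})

    # decimal=4 matches the existing mxnet ONNX op tests
    np.testing.assert_almost_equal(expected, actual, decimal=4)
```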
For now, I've verified the correctness of my conversion by:

1) Running this script:

<details>
<summary>test_rnn_export_onnx.py</summary>

```python
#!/bin/env python3
import numpy as np
import mxnet as mx
import onnxruntime as rt
import logging
import os

logging.getLogger().setLevel(logging.INFO)

# Configuration
path = 'lstm-model'
seq_len = 10
batch_size = 32
input_size = 16
hidden_size = 8
bidirectional = True

# Create input data
shape = (seq_len, batch_size, input_size)
I = mx.nd.random.randn(*shape)

# Create LSTM network
seq = mx.gluon.nn.HybridSequential()
with seq.name_scope():
    seq.add(mx.gluon.rnn.LSTM(hidden_size, bidirectional=bidirectional))
seq.initialize()
seq(I)

# Init all parameters with random values
for k, p in seq.collect_params().items():
    p.data()[...] = np.random.randn(*p.data().shape)

seq.hybridize(static_alloc=True, static_shape=True)
res_mx = seq(I)
res_mx = res_mx.asnumpy()

# Export to symbol and then to ONNX
seq.export(path)
mx.contrib.onnx.export_model(
    f'{path}-symbol.json', f'{path}-0000.params',
    [shape], np.float32,
    onnx_file_path=f'{path}.onnx',
    verbose=True
)

# Forward, using onnxruntime
sess = rt.InferenceSession(f'{path}.onnx')
res_onnx, = sess.run(None, {'data': I.asnumpy()})

# Check results
diff = np.abs(res_mx - res_onnx)
print(f"{diff.mean()=}")
print(f"{diff.max()=}")
# decimal=4 is the same as mxnet unittests for other ONNX ops
np.testing.assert_almost_equal(res_mx, res_onnx, decimal=4)

# Clean up files
os.remove(f"{path}-symbol.json")
os.remove(f"{path}-0000.params")
os.remove(f"{path}.onnx")
```

</details>

This generates a model with an `LSTM` node, fills all the parameters with random data, exports it to a `symbol.json` and `0000.params`, and then converts the model to `ONNX`. It then loads the model using `onnxruntime`, forwards some random data through both the `mxnet` and the `onnx` models, and asserts that the results are close.

2) Using this conversion on a CRNN network I trained at my job and verifying that the prediction results are the same on a subset of our test set. (Unfortunately, I can't share many details about this model, so you'll have to trust me here.)

## Comments ##

There are a few details in my current implementation that should be noted:

1) `_zeros`

   - It seems that `mx.sym` can broadcast a `_zeros/_ones/_full` node along an axis with dimension 0 as if it were a 1, while `ONNX` can't do that:

     ```python
     import mxnet as mx

     a = mx.sym.zeros((10, 0, 30))
     b = mx.sym.Variable('b')
     # The 0-sized axis of `a` broadcasts against the 20-sized axis of `b`
     (a * b).eval(b=mx.nd.zeros((10, 20, 30)))[0].shape
     ```

     Currently, my implementation converts all the 0 dims to 1s. Is this correct?

   - The version of `ONNX` that mxnet uses doesn't support some newer nodes like `ConstantOfShape`, and I was not able to find a working ONNX implementation that actually supported the currently deprecated `ConstantFill` op, so I had to store the `_zeros` data in a full tensor, which ends up being stored in the ONNX model file (see the sketch below).

     This also applies to the helper functions and the other nodes. I've had to settle for some "less than elegant" solutions in some cases, because of the lacking opset support.
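To make that last point concrete, here is a minimal sketch of the workaround. `make_zeros_constant` and its signature are hypothetical (the actual converter lives in `_op_translations.py` and takes its values from the node attributes), and whether the tensor ends up as a `Constant` node attribute or as a graph initializer is an implementation detail; the point is that the zeros are materialized as a full tensor:

```python
import numpy as np
from onnx import helper, mapping


def make_zeros_constant(name, shape, dtype=np.float32):
    """Hypothetical helper: materialize zeros as an ONNX Constant node."""
    # ONNX can't broadcast 0-sized axes the way mx.sym does,
    # so 0 dims are mapped to 1 (see the question above)
    shape = tuple(max(dim, 1) for dim in shape)
    data = np.zeros(shape, dtype=dtype)
    tensor = helper.make_tensor(
        name=name + '_value',
        data_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
        dims=shape,
        vals=data.flatten().tolist(),
    )
    # No ConstantOfShape in the target opset, so the full (all-zero)
    # tensor gets embedded in the model file
    return helper.make_node('Constant', inputs=[], outputs=[name],
                            name=name, value=tensor)
```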
I might have also missed a better way to do some operations, so if you have any better alternatives, I would gladly change my implementation.

2) `_rnn_param_concat`

   You can read more about my implementation in the code comments, but the general outline is that since `ONNX` expects the `RNN` parameters prepared in a completely different way compared to mxnet, I've had to employ some "dark" magic, involving `convert_rnn_param_concat` generating multiple nodes and `convert_RNN` 'guessing' these node names using regular expressions.

   As a result, the current conversion mechanism only works if the parameters of the RNN block are concatenated using `_rnn_param_concat` (and not, for example, `Concat`). If the conversion of `mx.sym.RNN` (without `gluon`) is a widely requested feature, we would have to convert the `_rnn_param_concat` node to a regular `Concat`, and then split the parameters back apart in the `RNN` conversion, which is not very nice IMO.

3) `RNN`

   Even though this PR implements the conversion of the `RNN` node, currently only the `LSTM` mode is supported, and many of the "customization" options are also not implemented. (A sketch of the target ONNX node is given after these lists.)

   **Implemented:**
   - `mode=LSTM`
   - `layout='TNC'`
   - `bidirectional=True` and `False`
   - `state_clip_min` and `state_clip_max` (but only if they are the same value)

   **Not implemented**, but probably supported by ONNX (easy to implement):
   - `mode=GRU, RNN_relu, RNN_tanh`
   - `use_sequence_length`
   - custom `states` inputs and `states` outputs

   **Not implemented**, and probably **not** supported by ONNX (hard to implement):
   - `state_clip_nan`
   - `state_clip_min != state_clip_max`
   - `dropout`
   - `projection_size`
   - `num_layers`
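For reference, a sketch of the ONNX side of the `RNN` conversion follows. The `helper.make_node` call follows the ONNX `LSTM` operator spec; the input/output names are illustrative, and the statement about mxnet's cuDNN-style `i, f, g, o` gate order is my own reading of the fused `RNN` op, so treat the reordering as an assumption rather than a confirmed detail of this PR:

```python
import numpy as np
from onnx import helper

hidden_size = 8

# The ONNX node the conversion ultimately targets (names are illustrative;
# the real converter derives them from the mxnet graph)
lstm_node = helper.make_node(
    'LSTM',
    # X, W, R, B, sequence_lens ('' = unused), initial_h, initial_c
    inputs=['data', 'W', 'R', 'B', '', 'initial_h', 'initial_c'],
    outputs=['Y', 'Y_h', 'Y_c'],
    hidden_size=hidden_size,
    direction='bidirectional',  # or 'forward'
)


def reorder_gates(w_ifgo):
    """Reorder a (4*hidden_size, ...) gate-stacked array from mxnet's
    assumed i, f, g, o layout to ONNX's required i, o, f, c layout."""
    i, f, g, o = np.split(w_ifgo, 4, axis=0)
    return np.concatenate([i, o, f, g], axis=0)
```

Note also that ONNX returns `Y` with shape `(seq_length, num_directions, batch_size, hidden_size)`, whereas mxnet's `RNN` in `TNC` layout produces `(seq_length, batch_size, num_directions * hidden_size)`, so the converted graph needs additional reshaping after the `LSTM` node.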
