This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new cfa1c89 [v1.x] Add AWDRNN Pretrained model test (#20018)
cfa1c89 is described below
commit cfa1c890a7ecb8b5e29ff4e90d6784141f09c4cd
Author: Zhaoqi Zhu <[email protected]>
AuthorDate: Thu Mar 18 17:47:40 2021 -0700
[v1.x] Add AWDRNN Pretrained model test (#20018)
* awd lstm
* add test
* add seq_len
* small fixes
* Update _op_translations.py
* Update test_onnxruntime.py
---
python/mxnet/contrib/onnx/mx2onnx/export_model.py | 1 +
tests/python-pytest/onnx/test_onnxruntime.py | 62 +++++++++++++++++++++++
2 files changed, 63 insertions(+)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
index 60e6a34..1c50db5 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_model.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -94,6 +94,7 @@ def export_model(sym, params, in_shapes=None,
in_types=np.float32,
if not isinstance(in_types, list):
in_types = [in_types for _ in range(len(in_shapes))]
in_types_t = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(i_t)] for i_t in
in_types]
+ assert len(in_types) == len(in_shapes), "The lengths of in_types and
in_shapes must equal"
# if input parameters are strings(file paths), load files and create
symbol parameter objects
if isinstance(sym, string_types) and isinstance(params, string_types):
logging.info("Converting json and weight file to sym and params")
diff --git a/tests/python-pytest/onnx/test_onnxruntime.py
b/tests/python-pytest/onnx/test_onnxruntime.py
index fa45bb6..e2a8329 100644
--- a/tests/python-pytest/onnx/test_onnxruntime.py
+++ b/tests/python-pytest/onnx/test_onnxruntime.py
@@ -870,6 +870,68 @@ def
test_dynamic_shape_bert_inference_onnxruntime(tmp_path, model):
@with_seed()
[email protected]('model_name', [('awd_lstm_lm_600', 600),
('awd_lstm_lm_1150', 1150)])
[email protected]('seq_length', [16, 128, 256])
+def test_awd_rnn_lstm_pretrained_inference_onnxruntime(tmp_path, model_name,
seq_length):
+ try:
+ import gluonnlp as nlp
+ ctx = mx.cpu()
+ dataset= 'wikitext-2'
+ model, _ = nlp.model.get_model(
+ name=model_name[0],
+ ctx=ctx,
+ pretrained=True,
+ dataset_name=dataset,
+ dropout=0)
+ model.hybridize()
+
+ batch = 2
+ num_hidden = model_name[1]
+ num_layers = 2
+ inputs = mx.nd.random.randint(0, 33278, shape=(seq_length, batch),
+ ctx=ctx).astype('float32')
+ begin_state = model.begin_state(func=mx.nd.random.uniform, low=0,
high=1,
+ batch_size=batch, dtype='float32',
ctx=ctx)
+ out, out_state= model(inputs, begin_state)
+
+ prefix = "%s/awd_lstm" % tmp_path
+ model.export(prefix)
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+
+ input_shapes = [(seq_length, batch),
+ np.shape(begin_state[0][0]),
np.shape(begin_state[0][1]),
+ np.shape(begin_state[1][0]),
np.shape(begin_state[1][1]),
+ np.shape(begin_state[2][0]),
np.shape(begin_state[2][1])]
+ input_types = [np.float32, np.float32, np.float32, np.float32,
np.float32, np.float32,
+ np.float32]
+ converted_model_path = mx.contrib.onnx.export_model(sym_file,
params_file, input_shapes,
+ input_types,
onnx_file, verbose=True)
+
+ sess_options = onnxruntime.SessionOptions()
+ sess_options.graph_optimization_level =
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ sess = onnxruntime.InferenceSession(onnx_file, sess_options)
+
+ in_tensors = [inputs, begin_state[0][0], begin_state[0][1],
+ begin_state[1][0], begin_state[1][1],
+ begin_state[2][0], begin_state[2][1]]
+ input_dict = dict((sess.get_inputs()[i].name, in_tensors[i].asnumpy())
for i in range(len(in_tensors)))
+ pred = sess.run(None, input_dict)
+
+ assert_almost_equal(out, pred[6])
+ assert_almost_equal(out_state[0][0], pred[0])
+ assert_almost_equal(out_state[0][1], pred[1])
+ assert_almost_equal(out_state[1][0], pred[2])
+ assert_almost_equal(out_state[1][1], pred[3])
+ assert_almost_equal(out_state[2][0], pred[4])
+ assert_almost_equal(out_state[2][1], pred[5])
+
+ finally:
+ shutil.rmtree(tmp_path)
+
+
+@with_seed()
@pytest.mark.parametrize('model_name', ['ernie_12_768_12'])
def test_ernie_inference_onnxruntime(tmp_path, model_name):
tmp_path = str(tmp_path)