Zha0q1 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r598365615
##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path,
model_name):
finally:
shutil.rmtree(tmp_path)
+
+
+@with_seed()
[email protected]('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+ tmp_path = str(tmp_path)
+ try:
+ import gluonnlp as nlp
+ dataset = 'WMT2014'
+ ctx = mx.cpu(0)
+ model, _, _ = nlp.model.get_model(
+ name=model_name,
+ ctx=ctx,
+ pretrained=True,
+ dataset_name=dataset)
+
+ model.hybridize(static_alloc=False)
+
+ batch = 7
+ seq_length = 16
+ C_in = 512
+ C_out = 512
+ src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length),
dtype='float32')
+ step_input = mx.nd.random.uniform(0, 36794, shape=(batch,),
dtype='float32')
+ src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+ encoder_outputs, encoder_additional_outputs = model.encode(src,
+
valid_length=src_valid_length)
+
+ decoder_states =
model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+
+ step_output, states, additional_outputs =
model.decode_step(step_input, decoder_states)
+
+ # skip export of 'decoder' as it's used for training only
+ for component in ['encoder', 'one_step_ahead_decoder', 'src_embed',
'tgt_embed',
+ 'tgt_proj']:
+
+ prefix = "%s/%s" %(tmp_path, component)
+ component = getattr(model, component)
+ component.export(prefix)
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+
+ def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+ sym_file = "%s-symbol.json" % prefix
+ params_file = "%s-0000.params" % prefix
+ onnx_file = "%s.onnx" % prefix
+ return mx.contrib.onnx.export_model(sym_file, params_file,
input_shapes, input_types,
+ onnx_file, **kwargs)
+
+ def onnx_runtime_predict(onnx_file, onnx_inputs):
+ ses_opt = onnxruntime.SessionOptions()
+ ses_opt.log_severity_level = 3
+ session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+ input_dict = dict((session.get_inputs()[i].name,
onnx_inputs[i].asnumpy())
+ for i in range(len(onnx_inputs)))
+ return session.run(None, input_dict)
+
+ def verify_encoder():
+ inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length,
C_in), dtype='float32')
+ valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+ pred = model.encoder(inputs, valid_length=valid_length)
+
+ prefix = "%s/encoder" %tmp_path
+ input_shapes = [(batch, seq_length, C_in), (batch,)]
+ input_types = [np.float32, np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [inputs, valid_length]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred[0], pred_onx[0])
+
+ def verify_src_embed():
+ src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length),
dtype='float32')
+ pred = model.src_embed(src)
+
+ prefix = "%s/src_embed" %tmp_path
+ input_shapes = [(batch, seq_length)]
+ input_types = [np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [src]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0])
+
+ def verify_tgt_embed():
+ tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length),
dtype='float32')
+ pred = model.tgt_embed(tgt)
+
+ prefix = "%s/tgt_embed" %tmp_path
+ input_shapes = [(batch, seq_length)]
+ input_types = [np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [tgt]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0])
+
+ def verify_tgt_proj():
+ decoder_out = mx.nd.random.uniform(0, 512, shape=(batch,
seq_length, C_out),
+ dtype='float32')
+ pred = model.tgt_proj(decoder_out)
+
+ prefix = "%s/tgt_proj" %tmp_path
+ input_shapes = [(batch, seq_length, C_out)]
+ input_types = [np.float32]
+ onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+ onnx_inputs = [decoder_out]
+ pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+ assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+ def verify_one_step_ahead_decoder():
+ prefix = "%s/one_step_ahead_decoder" %tmp_path
+
+ # the input data order
+ perm = [2, 0, 1]
Review comment:
I used a perm list so that the actual in_shapes and in_types lists can
have the same order as the inputs passed to the native model. It's just that the
converted ONNX model takes them in a different order somehow. I think this is
more consistent; what do you think?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]