waytrue17 commented on a change in pull request #20048:
URL: https://github.com/apache/incubator-mxnet/pull/20048#discussion_r598031213



##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512

Review comment:
       What do `C_in` and `C_out` represent? Should we also test a case where `C_in != C_out`?

##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), 
dtype='float32')

Review comment:
       Curious: doesn't `src` need to be an integer type?

##########
File path: tests/python-pytest/onnx/test_onnxruntime.py
##########
@@ -988,3 +988,170 @@ def test_ernie_inference_onnxruntime(tmp_path, model_name):
 
     finally:
         shutil.rmtree(tmp_path)
+
+
+@with_seed()
+@pytest.mark.parametrize('model_name', ['transformer_en_de_512'])
+def test_transformer_pretrained_inference_onnxruntime(tmp_path, model_name):
+    tmp_path = str(tmp_path)
+    try:
+        import gluonnlp as nlp
+        dataset = 'WMT2014'
+        ctx = mx.cpu(0)
+        model, _, _ = nlp.model.get_model(
+            name=model_name,
+            ctx=ctx,
+            pretrained=True,
+            dataset_name=dataset)
+
+        model.hybridize(static_alloc=False)
+
+        batch = 7
+        seq_length = 16
+        C_in = 512
+        C_out = 512
+        src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), 
dtype='float32')
+        step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), 
dtype='float32')
+        src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+
+        encoder_outputs, encoder_additional_outputs = model.encode(src,
+                                                                   
valid_length=src_valid_length)
+
+        decoder_states = 
model.decoder.init_state_from_encoder(encoder_outputs, src_valid_length)
+
+        step_output, states, additional_outputs = 
model.decode_step(step_input, decoder_states)
+
+        # skip export of 'decoder' as it's used for training only
+        for component in ['encoder', 'one_step_ahead_decoder', 'src_embed', 
'tgt_embed',
+                         'tgt_proj']:
+
+            prefix = "%s/%s" %(tmp_path, component)
+            component = getattr(model, component)
+            component.export(prefix)
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+
+        def export_to_onnx(prefix, input_shapes, input_types, **kwargs):
+            sym_file = "%s-symbol.json" % prefix
+            params_file = "%s-0000.params" % prefix
+            onnx_file = "%s.onnx" % prefix
+            return mx.contrib.onnx.export_model(sym_file, params_file, 
input_shapes, input_types,
+                                                onnx_file, **kwargs)
+
+        def onnx_runtime_predict(onnx_file, onnx_inputs):
+            ses_opt = onnxruntime.SessionOptions()
+            ses_opt.log_severity_level = 3
+            session = onnxruntime.InferenceSession(onnx_file, ses_opt)
+            input_dict = dict((session.get_inputs()[i].name, 
onnx_inputs[i].asnumpy())
+                            for i in range(len(onnx_inputs)))
+            return session.run(None, input_dict)
+
+        def verify_encoder():
+            inputs = mx.nd.random.uniform(-1, 1, shape=(batch, seq_length, 
C_in), dtype='float32')
+            valid_length = mx.nd.array([seq_length] * batch, dtype='float32')
+            pred = model.encoder(inputs, valid_length=valid_length)
+
+            prefix = "%s/encoder" %tmp_path
+            input_shapes = [(batch, seq_length, C_in), (batch,)]
+            input_types = [np.float32, np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [inputs, valid_length]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred[0], pred_onx[0])
+
+        def verify_src_embed():
+            src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), 
dtype='float32')
+            pred = model.src_embed(src)
+
+            prefix = "%s/src_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [src]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_embed():
+            tgt = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), 
dtype='float32')
+            pred = model.tgt_embed(tgt)
+
+            prefix = "%s/tgt_embed" %tmp_path
+            input_shapes = [(batch, seq_length)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [tgt]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0])
+
+        def verify_tgt_proj():
+            decoder_out = mx.nd.random.uniform(0, 512, shape=(batch, 
seq_length, C_out),
+                                               dtype='float32')
+            pred = model.tgt_proj(decoder_out)
+
+            prefix = "%s/tgt_proj" %tmp_path
+            input_shapes = [(batch, seq_length, C_out)]
+            input_types = [np.float32]
+            onnx_file = export_to_onnx(prefix, input_shapes, input_types)
+            onnx_inputs = [decoder_out]
+            pred_onx = onnx_runtime_predict(onnx_file, onnx_inputs)
+
+            assert_almost_equal(pred, pred_onx[0], rtol=1.e-04, atol=1.5e-03)
+
+        def verify_one_step_ahead_decoder():
+            prefix = "%s/one_step_ahead_decoder" %tmp_path
+
+            # the input data order
+            perm = [2, 0, 1]

Review comment:
       Could we build the list in the correct order when instantiating it,
instead of reordering it afterwards with `perm`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to