sxjscience opened a new pull request #6696:
URL: https://github.com/apache/incubator-tvm/pull/6696
Fix the MXNet 2.0 integration in relay. Tested the BERT and ALBERT models from
the new GluonNLP v1; both pass the numerical check in the script below. Unit
tests will later be added to GluonNLP to ensure that most backbones can run
with the graph runtime (a sketch of such a test follows the script).
```python
import mxnet as mx
import numpy as np
import gluonnlp
from gluonnlp.models import get_backbone
import numpy.testing as npt
mx.npx.set_np()
model_cls, cfg, tokenizer, backbone_param_path, _ = \
    get_backbone('google_albert_base_v2')
model = model_cls.from_cfg(cfg)
model.load_parameters(backbone_param_path)
model.hybridize()
batch_size = 1
seq_length = 128
# Random inputs for the numerical check
token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size,
                                 (batch_size, seq_length), dtype=np.int32)
token_types = mx.np.random.randint(0, 2, (batch_size, seq_length),
                                   dtype=np.int32)
valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                    (batch_size,), dtype=np.int32)
mx_out = model(token_ids, token_types, valid_length)
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
# Graph inputs: data0 = token ids, data1 = token types, data2 = valid length
shape_dict = {
    'data0': (batch_size, seq_length),
    'data1': (batch_size, seq_length),
    'data2': (batch_size,)
}
dtype_dict = {
    'data0': 'int32',
    'data1': 'int32',
    'data2': 'int32'
}
sym = model._cached_graph[1]
params = {}
for k, v in model.collect_params().items():
    params[v._var_name] = tvm.nd.array(v.data().asnumpy())
mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict,
                                        dtype=dtype_dict, arg_params=params)
print(mod)
# Target an NVIDIA T4 GPU (as in an AWS G4 instance)
target = "cuda -model=t4"
# FastMath rewrites ops such as erf/exp/tanh into fast approximations
with relay.build_config(opt_level=3, required_pass=["FastMath"]):
    graph, lib, cparams = relay.build(mod, target, params=params)
ctx = tvm.gpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
for i in range(rt.get_num_outputs()):
    out = rt.get_output(i)
    print(out.asnumpy())
    # Verify correctness against the MXNet outputs
    npt.assert_allclose(out.asnumpy(), mx_out[i].asnumpy(), rtol=1e-3,
                        atol=1e-2)
```
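
A rough sketch of the kind of unittest mentioned above. The
`verify_backbone_runs` helper and the second registry key are illustrative,
not part of this PR, and the `llvm` CPU target is used so the check can run
without a GPU:

```python
import mxnet as mx
import numpy as np
import numpy.testing as npt
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime
from gluonnlp.models import get_backbone

mx.npx.set_np()

def verify_backbone_runs(model_name, batch_size=1, seq_length=64):
    # Load the pretrained backbone and trace it via hybridization
    model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
    model = model_cls.from_cfg(cfg)
    model.load_parameters(backbone_param_path)
    model.hybridize()
    # Random inputs in the same layout as the script above
    token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size,
                                     (batch_size, seq_length), dtype=np.int32)
    token_types = mx.np.random.randint(0, 2, (batch_size, seq_length),
                                       dtype=np.int32)
    valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                        (batch_size,), dtype=np.int32)
    mx_out = model(token_ids, token_types, valid_length)
    # Convert the cached graph to relay and compile for CPU
    sym = model._cached_graph[1]
    params = {v._var_name: tvm.nd.array(v.data().asnumpy())
              for v in model.collect_params().values()}
    shape_dict = {'data0': (batch_size, seq_length),
                  'data1': (batch_size, seq_length),
                  'data2': (batch_size,)}
    dtype_dict = {k: 'int32' for k in shape_dict}
    mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict,
                                            dtype=dtype_dict, arg_params=params)
    with relay.build_config(opt_level=3):
        graph, lib, cparams = relay.build(mod, "llvm", params=params)
    rt = runtime.create(graph, lib, tvm.cpu())
    rt.set_input(**cparams)
    rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
    rt.run()
    # Outputs must match MXNet within tolerance
    for i in range(rt.get_num_outputs()):
        npt.assert_allclose(rt.get_output(i).asnumpy(), mx_out[i].asnumpy(),
                            rtol=1e-3, atol=1e-2)

# The BERT registry key below is illustrative; see gluonnlp.models for the
# exact names.
for name in ['google_albert_base_v2', 'google_en_uncased_bert_base']:
    verify_backbone_runs(name)
```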