Zheng-Bicheng commented on PR #16651:
URL: https://github.com/apache/tvm/pull/16651#issuecomment-1970274975
I integrated the inference code for TVM, PaddlePaddle, and ONNX. The code is
as follows:
```python
import paddle
import tvm
from tvm import relay
from tvm.contrib import graph_executor
import numpy as np
import onnx
import onnxruntime as rt
# Model attributes shared by the inference helpers below.
# `input_shape` is the shape given to the TVM frontends; `input_name` is the
# tensor name used in that shape dict and when setting the TVM input.
input_shape = [1, 3, 224, 224]
input_name = "inputs"
def infer_by_paddlepaddle(temp_prefix, temp_input_data):
    """Run a saved Paddle inference model on the CPU in static-graph mode.

    Args:
        temp_prefix: Path prefix handed to
            ``paddle.static.load_inference_model``.
        temp_input_data: Data fed to the model's first feed target.

    Returns:
        A ``(program, output)`` tuple: the loaded inference program and the
        single fetched result.
    """
    paddle.enable_static()
    executor = paddle.static.Executor(paddle.CPUPlace())
    loaded = paddle.static.load_inference_model(temp_prefix, executor)
    program, feed_names, fetch_vars = loaded
    feed = {feed_names[0]: temp_input_data}
    # Single-element unpack: fails loudly if the run yields more than one
    # result.
    [output] = executor.run(program, feed=feed, fetch_list=fetch_vars)
    return program, output
def infer_by_onnx(temp_model_path, temp_input_data):
    """Run an ONNX file through ONNX Runtime and also load it as a model.

    Args:
        temp_model_path: Path of the ``.onnx`` file.
        temp_input_data: Data bound to the session's first input.

    Returns:
        A ``(model, output)`` tuple: the model returned by
        ``onnx.load_model`` and the session's first output.
    """
    session = rt.InferenceSession(temp_model_path, None)
    first_input = session.get_inputs()[0].name
    first_output = session.get_outputs()[0].name
    results = session.run([first_output], {first_input: temp_input_data})
    output = results[0]
    # The loaded model is returned so the caller can pass it to
    # `infer_by_tvm`.
    model_proto = onnx.load_model(temp_model_path)
    return model_proto, output
def infer_by_tvm(temp_model, temp_input_data):
    """Compile a Paddle or ONNX model with TVM Relay and run it on the CPU.

    Args:
        temp_model: Either a ``paddle.static.Program`` or a model accepted
            by ``relay.frontend.from_onnx``.
        temp_input_data: Data bound to the input named ``input_name``.

    Returns:
        The first output of the compiled module as a NumPy array.
    """
    shape_dict = {input_name: input_shape}
    if isinstance(temp_model, paddle.static.Program):
        # Programs come from `paddle.static.load_inference_model`.
        mod, params = relay.frontend.from_paddle(
            temp_model, shape_dict=shape_dict
        )
    else:
        mod, params = relay.frontend.from_onnx(temp_model, shape=shape_dict)

    # NOTE(review): opt_level=5 is higher than the commonly used level 3.
    # Passes gated at level 4 (presumably including approximate fast-math
    # rewrites) may be enabled and could contribute to small numeric drift
    # -- confirm against the TVM pass documentation.
    with tvm.transform.PassContext(opt_level=5):
        lib = relay.build(mod, target="llvm", params=params)

    # TVM inference on the CPU device.
    dev = tvm.cpu()
    tvm_model = graph_executor.GraphModule(lib["default"](dev))
    tvm_model.set_input(input_name, temp_input_data)
    tvm_model.run()
    # `NDArray.asnumpy()` is deprecated; `numpy()` is the current spelling.
    return tvm_model.get_output(0).numpy()
log_file = "tune.json"

if __name__ == "__main__":
    np.random.seed(520)
    # Create the input data.
    input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

    paddle_prefix = "MobileNetV1_QAT/inference"
    paddle_model, paddle_output = infer_by_paddlepaddle(
        paddle_prefix, input_data
    )

    onnx_model_path = "MobileNetV1_QAT/inference.onnx"
    # Use the ONNX helper with the ONNX file. Calling the Paddle helper here
    # would make every "ONNX" comparison below a Paddle-vs-Paddle check.
    onnx_model, onnx_output = infer_by_onnx(onnx_model_path, input_data)

    # Compare the outputs of the Paddle model and the ONNX model
    # (reported as passing).
    np.testing.assert_allclose(
        paddle_output[0], onnx_output[0], rtol=1e-5, atol=1e-5
    )

    # Compare the outputs of the TVM-compiled Paddle model and the
    # TVM-compiled ONNX model (reported as passing).
    tvm_paddle_result = infer_by_tvm(paddle_model, input_data)
    tvm_onnx_result = infer_by_tvm(onnx_model, input_data)
    np.testing.assert_allclose(
        tvm_paddle_result[0], tvm_onnx_result[0], rtol=1e-5, atol=1e-5
    )

    # Compare the outputs of the Paddle model and the TVM-compiled Paddle
    # model (disabled; enable to reproduce the TVM-vs-Paddle mismatch).
    # np.testing.assert_allclose(
    #     tvm_paddle_result[0], paddle_output[0], rtol=1e-5, atol=1e-5
    # )

    # Compare the outputs of the ONNX model and the TVM-compiled ONNX model.
    np.testing.assert_allclose(
        tvm_onnx_result[0], onnx_output[0], rtol=1e-5, atol=1e-5
    )
```
I found that when inputting the same data, the output data of the Paddle
model and the ONNX model are consistent.
The differences between TVM and Paddle are as follows:
```text
Mismatched elements: 4 / 1000 (0.4%)
Max absolute difference: 0.01572984
Max relative difference: 1.
x: array([0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,...
y: array([0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,...
```
The differences between TVM and ONNX are as follows:
```text
Mismatched elements: 4 / 1000 (0.4%)
Max absolute difference: 0.01572984
Max relative difference: 1.
x: array([0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,...
y: array([0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 0. ,...
```
Therefore, my initial statement should be considered incorrect; under the
same data conditions, both the Paddle model and the ONNX model exhibit the same
symptoms.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]