coffezhou opened a new issue, #18003: URL: https://github.com/apache/tvm/issues/18003
### Expected behavior

The output of the TVM CPU backend should be identical to the outputs of onnxruntime and the TVM CUDA backend.

### Actual behavior

For the model in the attached testcase, compiling it for the CPU target produces:

```text
[[[[False]]]]
```

However, onnxruntime and the TVM CUDA backend both produce:

```text
[[[[ True]]]]
```

### Environment

* OS: Ubuntu 20.04
* TVM: 0.21.dev0 (3db71bb3a)
* onnxruntime: 1.21.0
* CUDA: 11.8
* GPU: NVIDIA GeForce RTX 3080

### Steps to reproduce

This bug can be reproduced with the following code and the model in the attachment.

```python
import sys

import numpy as np
import onnx
import onnxruntime
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import pickle


def main():
    onnx_model = onnx.load("a2631.onnx")
    shape_onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
    onnx.save(shape_onnx_model, '1111.onnx')
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:
        print(e)
        sys.exit(1)

    print("ONNXRuntime:\n", ort_output)

    # Convert the onnx model into relax through the onnx importer.
    tvm_model = from_onnx(onnx_model, keep_params_in_input=True)
    # Convert operators for inference mode.
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    # Legalize any relax ops into tensorir.
    tvm_model = relax.transform.LegalizeOps()(tvm_model)

    # Separate model from parameters.
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    # Prepare inputs.
    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    # Compile the relax graph into a VM then run.
    #----------------------cpu-----------------------
    with tvm.transform.PassContext(opt_level=2):
        target = tvm.target.Target("llvm", host="llvm")
        relax_pipeline = relax.pipeline.get_default_pipeline(target)
        ex = relax.build(tvm_model, target="llvm", relax_pipeline=relax_pipeline)
        vm = relax.VirtualMachine(ex, tvm.cpu())
    # Run model and check outputs.
    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_cpu_output = vm.get_outputs("main")
    print("TVM-CPU:\n", tvm_cpu_output)
    #----------------------cpu-----------------------

    #----------------------cuda-----------------------
    with tvm.target.Target("cuda"):
        tvm_model = tvm.tir.transform.DefaultGPUSchedule()(tvm_model)
    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="cuda")
        vm1 = relax.VirtualMachine(ex, tvm.cuda())
    vm1.set_input("main", *input_list)
    vm1.invoke_stateful("main")
    tvm_gpu_output = vm1.get_outputs("main")
    print("TVM-CUDA:\n", tvm_gpu_output)
    #----------------------cuda-----------------------


if __name__ == "__main__":
    main()
```

[testcase.zip](https://github.com/user-attachments/files/20359142/testcase.zip)

### Triage

* needs-triage
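The reproducer prints each backend's result, and the reader compares them by eye. The helper below makes that comparison explicit. It is a minimal sketch that assumes each backend's outputs have already been converted to a list of NumPy arrays, for example via `.numpy()` on the TVM results. `outputs_match` is a hypothetical name and is not part of TVM or onnxruntime. Boolean outputs like the one in this report are compared exactly, while floating-point outputs use a tolerance.

```python
import numpy as np


def outputs_match(reference, outputs, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: True if `outputs` agrees with `reference`.

    Both arguments are lists of NumPy arrays, one entry per model output.
    """
    if len(reference) != len(outputs):
        return False
    for ref, out in zip(reference, outputs):
        ref, out = np.asarray(ref), np.asarray(out)
        if ref.shape != out.shape:
            return False
        if ref.dtype == np.bool_ or out.dtype == np.bool_:
            # A tolerance is meaningless for booleans, so require exact equality.
            if not np.array_equal(ref, out):
                return False
        elif not np.allclose(ref, out, rtol=rtol, atol=atol, equal_nan=True):
            return False
    return True


if __name__ == "__main__":
    # Values taken from the outputs reported in this issue.
    ort = [np.array([[[[True]]]])]
    backends = {
        "TVM-CPU": [np.array([[[[False]]]])],
        "TVM-CUDA": [np.array([[[[True]]]])],
    }
    for name, outs in backends.items():
        # Prints "TVM-CPU matches onnxruntime: False", then
        # "TVM-CUDA matches onnxruntime: True".
        print(name, "matches onnxruntime:", outputs_match(ort, outs))
```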
