coffezhou opened a new issue, #17965: URL: https://github.com/apache/tvm/issues/17965
### Expected behavior

TVM should produce consistent results for the CPU and GPU targets.

### Actual behavior

For the model in the attached test case, compiling for the CPU target produces an output that is entirely NaN:

```
cpu: [[[[nan nan nan]
   [nan nan nan]
   [nan nan nan]]

  [[nan nan nan]
   [nan nan nan]
   [nan nan nan]]

  [[nan nan nan]
   [nan nan nan]
   [nan nan nan]]]]
```

However, when the target is CUDA, the output contains finite values, many of them -3.4028231e+38:

```
gpu: [[[[ 9.5653236e-01 8.9820576e-01 8.9820576e-01]
   [ 9.5653236e-01 -3.4028231e+38 -3.4028231e+38]
   [-3.4028231e+38 -3.4028231e+38 -3.4028231e+38]]

  [[ 9.5653236e-01 8.9820576e-01 8.9820576e-01]
   [ 9.5653236e-01 -3.4028231e+38 -3.4028231e+38]
   [-3.4028231e+38 -3.4028231e+38 -3.4028231e+38]]

  [[ 9.5653236e-01 8.9820576e-01 8.9820576e-01]
   [ 9.5653236e-01 -3.4028231e+38 -3.4028231e+38]
   [-3.4028231e+38 -3.4028231e+38 -3.4028231e+38]]]]
```

### Environment

* OS: Ubuntu 20.04
* TVM: 0.21.dev0 (bcb68b130)
* CUDA: 11.8
* GPU: NVIDIA GeForce RTX 3080

### Steps to reproduce

This bug can be reproduced with the following script and the model in the attachment.

```python
import sys
import numpy as np
import onnx
import onnxruntime
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import pickle


def main():
    onnx_model = onnx.load("a249.onnx")
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    # Convert the ONNX model into Relax through the ONNX importer.
    tvm_model = from_onnx(onnx_model, keep_params_in_input=True)
    # Convert operators for inference mode.
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    # Legalize any Relax ops into TensorIR.
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    # Separate the model from its parameters.
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    # Prepare inputs in the order of the "main" function's parameters.
    input_list = [
        inputs[key.name_hint]
        for key in tvm_model["main"].params
        if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    # Compile the Relax graph into a VM, then run.
    #----------------------cpu-----------------------
    with tvm.transform.PassContext(opt_level=0):
        ex = relax.build(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    # Run the model and check the outputs.
    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_cpu_output = vm.get_outputs("main")
    print("cpu: ", tvm_cpu_output)
    #----------------------cpu-----------------------

    #----------------------cuda-----------------------
    # Apply a default GPU schedule (thread bindings) so the TIR functions can run on CUDA.
    with tvm.target.Target("cuda"):
        tvm_model = tvm.tir.transform.DefaultGPUSchedule()(tvm_model)
    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="cuda")
    vm1 = relax.VirtualMachine(ex, tvm.cuda())
    vm1.set_input("main", *input_list)
    vm1.invoke_stateful("main")
    tvm_gpu_output = vm1.get_outputs("main")
    print("gpu: ", tvm_gpu_output)
    #----------------------cuda-----------------------


if __name__ == "__main__":
    main()
```

[testcase.zip](https://github.com/user-attachments/files/20180366/testcase.zip)

### Triage

* needs-triage
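To make the CPU/GPU mismatch easy to check programmatically, the outputs can be compared after converting them to NumPy, e.g. with `tvm_cpu_output.numpy()` and `tvm_gpu_output.numpy()`. The following is a minimal sketch under that assumption. The helper names `summarize` and `outputs_match`, the `1e38` magnitude threshold, and the tolerances are hypothetical choices for illustration and are not part of TVM.

```python
import numpy as np


def summarize(name, arr):
    """Print and return (NaN count, count of finite values near the float32 limit)."""
    arr = np.asarray(arr, dtype=np.float32)
    n_nan = int(np.isnan(arr).sum())
    # Finite values whose magnitude exceeds 1e38, i.e. close to the float32 limit (~3.4e38).
    n_huge = int((np.isfinite(arr) & (np.abs(arr) > 1e38)).sum())
    print(f"{name}: shape={arr.shape} nan={n_nan} near_float32_limit={n_huge}")
    return n_nan, n_huge


def outputs_match(cpu_out, gpu_out, rtol=1e-5, atol=1e-5):
    """Return True if both outputs have the same shape and agree within tolerance.

    NaNs in the same positions are treated as equal (equal_nan=True).
    """
    cpu = np.asarray(cpu_out, dtype=np.float32)
    gpu = np.asarray(gpu_out, dtype=np.float32)
    if cpu.shape != gpu.shape:
        return False
    return bool(np.allclose(cpu, gpu, rtol=rtol, atol=atol, equal_nan=True))
```

In the reproduction script, calling `outputs_match(tvm_cpu_output.numpy(), tvm_gpu_output.numpy())` after the two `print` calls would return `False` for the outputs reported above, since one side is all NaN and the other is not.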
