coffezhou opened a new issue, #17854:
URL: https://github.com/apache/tvm/issues/17854

   ### Expected behavior
   
   TVM should compile the model correctly with the CUDA backend.
   
   ### Actual behavior
   
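   When compiling the model with the CUDA backend, TVM crashes as shown below.
   The generated `add1_kernel` declares its last parameter as an unnamed
   `int64_t` and uses an empty left operand in `( * (int64_t)3)`, so nvcc
   rejects it with `operand of "*" must be a pointer`. The later
   `InternalError` from `Target::ExitWithScope` is raised while handling that
   first error: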
   ```text
   Traceback (most recent call last):
     File "/home/carla/Documents/test/test.py", line 140, in check_correctness
       ex = tvm.compile(tvm_model, target="cuda")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/driver/build_module.py", line 
104, in compile
       return tvm.relax.build(
              ^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 259, 
in build
       return _vmlink(
              ^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 154, 
in _vmlink
       lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 186, in 
build
       return tir_to_runtime(host_mod, device_mod_dict, target_host)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 96, in 
tir_to_runtime
       device_modules.append(codegen_build(device_mod, target))
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 80, in 
codegen_build
       return bf(mod, target)
              ^^^^^^^^^^^^^^^
     File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in 
tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in 
tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in 
tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 185, in 
tvm._ffi._cy3.core.CHECK_CALL
     File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in 
raise_last_ffi_error
       raise py_err
     File "/home/carla/Documents/tvm/src/target/opt/build_cuda_on.cc", line 
161, in tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
       ptx = (*f)(code, target).operator std::string();
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in 
tvm._ffi._cy3.core.tvm_callback
     File "/home/carla/Documents/tvm/python/tvm/contrib/nvcc.py", line 204, in 
tvm_callback_cuda_compile
       ptx = compile_cuda(code, target_format="fatbin")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/python/tvm/contrib/nvcc.py", line 128, in 
compile_cuda
       raise RuntimeError(msg)
   RuntimeError: 
   #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
        (__CUDACC_VER_MAJOR__ > 11))
   #define TVM_ENABLE_L2_PREFETCH 1
   #else
   #define TVM_ENABLE_L2_PREFETCH 0
   #endif
   
   #ifdef _WIN32
     using uint = unsigned int;
     using uchar = unsigned char;
     using ushort = unsigned short;
     using int64_t = long long;
     using uint64_t = unsigned long long;
   #else
     #define uint unsigned int
     #define uchar unsigned char
     #define ushort unsigned short
     #define int64_t long long
     #define uint64_t unsigned long long
   #endif
   extern "C" __global__ void __launch_bounds__(1024) add1_kernel(float* 
__restrict__ T_add, float* __restrict__ lv6, float* __restrict__ lv7, int64_t );
   extern "C" __global__ void __launch_bounds__(1024) add1_kernel(float* 
__restrict__ T_add, float* __restrict__ lv6, float* __restrict__ lv7, int64_t ) 
{
     for (int64_t ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < ((( * (int64_t)3) + 
(int64_t)262143) >> (int64_t)18); ++ax0_ax1_fused_0) {
       if ((((ax0_ax1_fused_0 * (int64_t)262144) + (((int64_t)blockIdx.x) * 
(int64_t)1024)) + ((int64_t)threadIdx.x)) < ( * (int64_t)3)) {
         T_add[(((ax0_ax1_fused_0 * (int64_t)262144) + (((int64_t)blockIdx.x) * 
(int64_t)1024)) + ((int64_t)threadIdx.x))] = (lv6[(((ax0_ax1_fused_0 * 
(int64_t)262144) + (((int64_t)blockIdx.x) * (int64_t)1024)) + 
((int64_t)threadIdx.x))] + lv7[(((ax0_ax1_fused_0 * (int64_t)262144) + 
(((int64_t)blockIdx.x) * (int64_t)1024)) + ((int64_t)threadIdx.x))]);
       }
     }
   }
   
   
   Compilation error:
   /tmp/tmpjs6_cjiu/tvm_kernels.cu(24): error: operand of "*" must be a pointer 
but has type "long long"
   
   /tmp/tmpjs6_cjiu/tvm_kernels.cu(25): error: operand of "*" must be a pointer 
but has type "long long"
   
   2 errors detected in the compilation of "/tmp/tmpjs6_cjiu/tvm_kernels.cu".
   
   
   During handling of the above exception, another exception occurred:
   
   Traceback (most recent call last):
     File "/home/carla/Documents/test_tvm/test-tvm-llm/test_tvm.py", line 189, 
in <module>
       main()
     File "/home/carla/Documents/test_tvm/test-tvm-llm/test_tvm.py", line 178, 
in main
       check_correctness(onnx_model)
     File "/home/carla/Documents/test_tvm/test-tvm-llm/test_tvm.py", line 136, 
in check_correctness
       with tvm.target.Target("cuda"):
     File "/home/carla/Documents/tvm/python/tvm/target/target.py", line 145, in 
__exit__
       _ffi_api.TargetExitScope(self)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in 
tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in 
tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in 
tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 185, in 
tvm._ffi._cy3.core.CHECK_CALL
     File "/home/carla/Documents/tvm/python/tvm/_ffi/base.py", line 468, in 
raise_last_ffi_error
       raise py_err
     File "/home/carla/Documents/tvm/src/target/target.cc", line 52, in 
tvm::TargetInternal::ExitScope(tvm::Target)
       static void ExitScope(Target target) { target.ExitWithScope(); }
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/carla/Documents/tvm/src/target/target.cc", line 744, in 
tvm::Target::ExitWithScope()
       ICHECK(entry->context_stack.top().same_as(*this));
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
   tvm.error.InternalError: Traceback (most recent call last):
     1: tvm::TargetInternal::ExitScope(tvm::Target)
           at /home/carla/Documents/tvm/src/target/target.cc:52
     0: tvm::Target::ExitWithScope()
           at /home/carla/Documents/tvm/src/target/target.cc:744
     File "/home/carla/Documents/tvm/src/target/target.cc", line 744
   InternalError: Check failed: (entry->context_stack.top().same_as(*this)) is 
false: 
   
   ```
   
   ### Environment
   
   OS: Ubuntu 20.04
   TVM: 0.21.dev0(c00f52a70)
   
   ### Steps to reproduce
   
   This bug can be reproduced with the following code and the model in the
   attachment. As the code shows, the model runs under onnxruntime and also
   compiles with TVM's CPU (`llvm`) backend; however, TVM fails to compile it
   with the CUDA backend. (The traceback above was produced by a variant of
   this script in which the compile step lives in a `check_correctness`
   function.)
   ```python
   import sys
   
   import numpy as np
   import onnx
   import onnxruntime
   
   import tvm
   from tvm import relax
   from tvm.relax.frontend.onnx import from_onnx
   
   import argparse
   import pickle
   
               
   def main():
       """Reproduce the CUDA build failure reported in apache/tvm#17854.

       Loads ``a2.onnx`` and ``inputs.pkl`` from the working directory, runs
       the model under onnxruntime (CPUExecutionProvider), imports it into
       Relax, builds and runs it for the ``llvm`` target, then applies
       ``DefaultGPUSchedule`` and builds it for ``cuda``. Per the report, the
       CUDA ``tvm.compile`` call is where nvcc rejects the generated kernel.

       NOTE: some long lines below were hard-wrapped by the mail gateway and
       continue at column 0; they are still inside brackets, so Python joins
       them implicitly.
       """
       onnx_model = onnx.load("a2.onnx")
       
       # NOTE(review): pickle.load can execute arbitrary code from the file;
       # only use this with a trusted inputs.pkl.
       with open("inputs.pkl", "rb") as fp:
           inputs = pickle.load(fp)
       
       # Sanity-check that the model itself runs under onnxruntime; abort
       # the reproduction if it does not.
       try:
           ort_session = onnxruntime.InferenceSession(
               onnx_model.SerializeToString(), 
providers=["CPUExecutionProvider"]
           )
           ort_output = ort_session.run([], inputs)
       except Exception as e:
           print(e)
           sys.exit(1)
           
       # Convert the onnx model into relax through the onnx importer.
       tvm_model = from_onnx(onnx_model, keep_params_in_input=True)
       # Convert operators for inference mode.
       tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
       # Legalize any relax ops into tensorir.
       tvm_model = relax.transform.LegalizeOps()(tvm_model)
   
       # Separate model from parameters.
       tvm_model, params = relax.frontend.detach_params(tvm_model)
       
       # Prepare inputs: keep only the "main" parameters that have a value in
       # inputs.pkl, then append the detached weights (if any).
       input_list = [
           inputs[key.name_hint] for key in tvm_model["main"].params if 
key.name_hint in inputs
       ]
       if params:
           input_list += params["main"]
           
       # Compile the relax graph into a VM then run.
       #----------------------cpu-----------------------
       # The CPU path is reported to build and run successfully.
       with tvm.transform.PassContext(opt_level=3):
           target = tvm.target.Target("llvm", host="llvm")
           relax_pipeline = relax.pipeline.get_default_pipeline(target)
           
           # NOTE(review): the build is given the string "llvm" rather than
           # the `target` object created above.
           ex = relax.build(tvm_model, target="llvm", 
relax_pipeline=relax_pipeline)
           vm = relax.VirtualMachine(ex, tvm.cpu())
       
           # Run model and collect outputs.
           # NOTE(review): despite the original intent to "check outputs",
           # tvm_cpu_output is never compared with ort_output.
           vm.set_input("main", *input_list)
           vm.invoke_stateful("main")
           tvm_cpu_output = vm.get_outputs("main")
       #----------------------cpu-----------------------
       
       #----------------------cuda-----------------------
       # The GPU schedule is applied inside a "cuda" target scope; presumably
       # DefaultGPUSchedule reads the current target from it -- TODO confirm.
       with tvm.target.Target("cuda"):
           tvm_model = tvm.tir.transform.DefaultGPUSchedule()(tvm_model) 
   
           with tvm.transform.PassContext(opt_level=3):
               # This is the call the reported traceback fails in (nvcc error
               # in the generated add1_kernel).
               ex = tvm.compile(tvm_model, target="cuda")
               vm1 = relax.VirtualMachine(ex, tvm.cuda())
               
           # NOTE(review): tvm_gpu_output is never compared with ort_output
           # or tvm_cpu_output.
           vm1.set_input("main", *input_list)
           vm1.invoke_stateful("main")
           tvm_gpu_output = vm1.get_outputs("main")
       #----------------------cuda-----------------------
       
       
   # Run the reproduction only when executed as a script, not on import.
   if __name__ == "__main__":
       main()
   ```
   
   
[testcase.zip](https://github.com/user-attachments/files/19809073/testcase.zip)
   
   ### Triage
   
   * needs-triage
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to