KangHe000 opened a new issue, #13639:
URL: https://github.com/apache/tvm/issues/13639

   ### Expected behavior
   
   Hi, I tried to run `test_cuda_tensor_core()` from test_auto_tensorize.py: 
https://github.com/apache/tvm/blob/main/tests/python/integration/test_auto_tensorize.py
   
   but it aborted with the cudaFree() error shown below. 
   
   ### Actual behavior
   
   ```
   One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
   terminate called after throwing an instance of 'tvm::runtime::InternalError'
     what():  [21:57:23] /home/hekang/gitlab_src/tvm/src/runtime/cuda/cuda_device_api.cc:135: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (e == cudaSuccess || e == cudaErrorCudartUnloading) is false: CUDA: unspecified launch failure
   Stack trace:
     0: _ZN3tvm7runtime6detail
     1: tvm::runtime::CUDADeviceAPI::FreeDataSpace(DLDevice, void*)
     2: tvm::runtime::NDArray::Internal::DefaultDeleter(tvm::runtime::Object*)
     3: tvm::runtime::GraphExecutor::~GraphExecutor()
     4: _ZN3tvm7runtime18SimpleObjAllocator7
     5: tvm::runtime::GraphExecutorFactory::ExecutorCreate(std::vector<DLDevice, std::allocator<DLDevice> > const&)
     6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::GraphExecutorFactory::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     7: TVMFuncCall
   
   
   Aborted (core dumped)
   ```
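
A note on reading this trace: `unspecified launch failure` is normally caused by a kernel that faulted on the device, and CUDA reports such errors asynchronously at a later API call, which here happens to be the free inside `CUDADeviceAPI::FreeDataSpace`. A minimal sketch for narrowing it down, assuming the reproduction script is saved under a hypothetical name such as `repro.py`: rerun it in a child process with `CUDA_LAUNCH_BLOCKING=1` so that kernel launches become synchronous. `build_debug_env` and `run_repro` are illustrative helpers, not part of TVM.

```python
import os
import subprocess
import sys


def build_debug_env(base_env=None):
    """Return a copy of the environment with synchronous CUDA launches enabled."""
    env = dict(os.environ if base_env is None else base_env)
    # CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so a faulting
    # kernel is reported at its own launch rather than at a later API call.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def run_repro(script_path):
    """Run the reproduction script (hypothetical path) in a child process."""
    # A child process is used because the abort would otherwise kill the caller.
    result = subprocess.run(
        [sys.executable, script_path],
        env=build_debug_env(),
        capture_output=True,
        text=True,
    )
    return result.returncode, result.stderr
```

Calling `run_repro("repro.py")` returns the exit code and the captured stderr of the child. With blocking launches, the reported stack trace may point closer to the kernel that actually failed.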
   
   ### Environment
   
   * Target: `tvm.target.Target("nvidia/geforce-rtx-2080-ti")`
   * OS: Ubuntu 18.04
   * CUDA: 10.2
   * PyTorch: 1.12.1+cu102
   
   ### Code to reproduce
   
   ```
   """Integration test for MetaSchedule's auto tensorization."""
   import tempfile
   import numpy as np
   import pytest
   import tvm
   import tvm.testing
   import tvm.topi.testing
   from tvm import meta_schedule as ms
   from tvm import relay
   from tvm.meta_schedule.testing import relay_workload
   
   def test_cuda_tensor_core(model_name, input_shape):
       """Integration tests of auto tensorization with CUDA tensor core"""
       target = tvm.target.Target("nvidia/geforce-rtx-2080-ti")
       dev = tvm.cuda()
       if model_name.startswith("bert"):
           data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
       else:
           data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
   
       mod, params, (input_name, _, _) = relay_workload.get_network(model_name, input_shape)
       seq = tvm.transform.Sequential(
           [
               relay.transform.ToMixedPrecision(),
           ]
       )
       with tvm.transform.PassContext(opt_level=3):
           mod = seq(mod)
   
       def convert_layout(mod):
           seq = tvm.transform.Sequential(
               [relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})]
           )
           with tvm.transform.PassContext(opt_level=3):
               mod = seq(mod)
           return mod
       with tempfile.TemporaryDirectory() as work_dir:
           with ms.Profiler() as profiler:
               converted_mod = convert_layout(mod)
               database = ms.relay_integration.tune_relay(
                   mod=converted_mod,
                   target=target,
                   work_dir=work_dir,
                   max_trials_global=3000,
                   params=params,
               )
               rt_mod1 = ms.relay_integration.compile_relay(
                   database=database,
                   mod=converted_mod,
                   target=target,
                   params=params,
               )
           print(profiler.table())
   
           # Compile without MetaSchedule for correctness check
           with tvm.transform.PassContext(opt_level=0):
               rt_mod2 = relay.build(mod, target=target, params=params)
   
           def get_output(data, lib):
               module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
               module.set_input(input_name, data)
               module.run()
               return module.get_output(0).numpy()
   
           # Check correctness
           actual_output = get_output(data, rt_mod1)
           print("actual_output = ", actual_output)
           expected_output = get_output(data, rt_mod2)
           print("expected_output = ", expected_output)
   
           assert np.allclose(actual_output, expected_output, rtol=1e-2, atol=2e-2)
   
   if __name__ == "__main__":
       model_name = "bert_base"
       input_shape = (8, 128)
       test_cuda_tensor_core(model_name, input_shape)
       
   
   ```
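
For reference, the final assertion in the script uses `np.allclose`, which passes when `|actual - expected| <= atol + rtol * |expected|` holds for every element. The pure-Python sketch below applies the same rule to flat sequences and also reports the worst offending index, which can help once the run gets past the crash. `allclose_report` is a hypothetical helper; it assumes finite inputs and treats NaN as a mismatch.

```python
import math


def allclose_report(actual, expected, rtol=1e-2, atol=2e-2):
    """Return (ok, worst_index, worst_excess) for two flat sequences of floats."""
    worst_index, worst_excess = None, 0.0
    for i, (a, e) in enumerate(zip(actual, expected)):
        # Amount by which the difference exceeds the allowed tolerance.
        excess = abs(a - e) - (atol + rtol * abs(e))
        if math.isnan(excess):
            return False, i, math.inf
        if excess > worst_excess:
            worst_index, worst_excess = i, excess
    return worst_index is None, worst_index, worst_excess
```

With the tolerances from the script, `allclose_report([1.0, 2.0], [1.0, 2.01])` passes, because the difference of 0.01 is below the allowed `0.02 + 0.01 * 2.01`.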
   
   
   ### Triage
   * needs-triage
   

