KJlaccHoeUM9l opened a new issue, #11309:
URL: https://github.com/apache/tvm/issues/11309

   While experimenting with the proposed workaround for cooling down the GPU (see this [issue](https://github.com/apache/tvm/issues/11307)), we ran into an invalid memory access when using `VirtualMachine`.
   
   The following error occurs:
   ```Shell
   Check failed: (e == cudaSuccess || e == cudaErrorCudartUnloading) is false: CUDA: invalid argument
   ```
   
   The error occurs in the following scenario:
   
   1. The first inference completes;
   2. The GPU is reset;
   3. The second inference is launched;
   4. The output tensor is copied from the GPU to the CPU;
   5. The error occurs.
   
   The invalid memory access after a GPU reset is specific to the Virtual Machine: with the Graph Executor everything works correctly. The problem arises when we try to call the `get_output` method (the output paths of the VM and the Graph Executor are arranged differently).
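   For reference, "invalid argument" is what the CUDA runtime reports when an API call is handed a pointer or handle from a context that no longer exists. Below is a minimal standalone sketch of that failure class (plain CUDA runtime code, not TVM; a hypothetical illustration, not the actual TVM code path):

   ```cpp
   // Sketch: reuse a device pointer after cudaDeviceReset().
   // The reset destroys every allocation in the primary context, so the
   // later cudaMemcpy receives a stale pointer and fails -- typically with
   // cudaErrorInvalidValue, whose message is "invalid argument".
   #include <cassert>
   #include <cstdio>
   #include <cuda_runtime.h>

   int main() {
     float host[4] = {0.0f};
     float* dev = nullptr;
     assert(cudaMalloc(&dev, sizeof(host)) == cudaSuccess);

     // Analogue of the "device_api.cuda_reset" workaround in this issue.
     assert(cudaDeviceReset() == cudaSuccess);

     // `dev` now points into a destroyed context.
     cudaError_t e = cudaMemcpy(host, dev, sizeof(host), cudaMemcpyDeviceToHost);
     std::printf("%s\n", cudaGetErrorString(e));
     return 0;
   }
   ```

   If the VM runtime holds on to some device-side state across the reset while the Graph Executor path happens to recreate everything it needs, that would explain the observed difference; this is a hypothesis, not a confirmed diagnosis.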
   
   ### Expected behavior
   
   After resetting the GPU, we can create a new `VirtualMachine` object and run 
inference. The inference completes without error.
   
   ### Actual behavior
   
   A runtime error occurs:
   ```Shell
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     30: 0xffffffffffffffff
     29: _start
     28: __libc_start_main
     27: Py_BytesMain
     26: Py_RunMain
     25: PyRun_SimpleFileExFlags
     24: 0x000000000067e470
     23: 0x000000000067e3ce
     22: 0x000000000067e350
     21: PyEval_EvalCode
     20: _PyEval_EvalCodeWithName
     19: _PyEval_EvalFrameDefault
     18: _PyFunction_Vectorcall
     17: _PyEval_EvalFrameDefault
     16: _PyFunction_Vectorcall
     15: _PyEval_EvalCodeWithName
     14: _PyEval_EvalFrameDefault
     13: _PyFunction_Vectorcall
     12: _PyEval_EvalCodeWithName
     11: _PyEval_EvalFrameDefault
     10: _PyFunction_Vectorcall
     9: _PyEval_EvalFrameDefault
     8: _PyObject_MakeTpCall
     7: 0x00007fbb3ac0a9db
     6: _ctypes_callproc
     5: 0x00007fbb3abf1409
     4: 0x00007fbb3abf1ff4
     3: TVMArrayCopyToBytes
     2: tvm::runtime::ArrayCopyToBytes(DLTensor const*, void*, unsigned long)
     1: tvm::runtime::DeviceAPI::CopyDataFromTo(DLTensor*, DLTensor*, void*)
     0: tvm::runtime::CUDADeviceAPI::CopyDataFromTo(void const*, unsigned long, void*, unsigned long, unsigned long, DLDevice, DLDevice, DLDataType, void*)
     File "<tvm_root>/src/runtime/cuda/cuda_device_api.cc", line 234
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (e == cudaSuccess || e == cudaErrorCudartUnloading) is false: CUDA: invalid argument
   ```
   
   ### Environment
   Key | Value
   -- | --
   GPU | Tesla T4 (16 GB)
   CPU | Intel(R) Xeon(R) CPU @ 2.00GHz
   System | Ubuntu 20.04.3 LTS
   Target | x86_64-linux-gnu
   CUDA | 11.1
   LLVM | 12
   
   TVM was built with the following changes applied:
   ```Shell
   diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
   index b4d7b41b7..a60d4ece4 100644
   --- a/src/runtime/cuda/cuda_device_api.cc
   +++ b/src/runtime/cuda/cuda_device_api.cc
   @@ -252,6 +252,11 @@ TVM_REGISTER_GLOBAL("device_api.cuda_host").set_body([](TVMArgs args, TVMRetValu
      *rv = static_cast<void*>(ptr);
    });
    
   +TVM_REGISTER_GLOBAL("device_api.cuda_reset").set_body_typed([](int device_id) {
   +  CUDA_CALL(cudaSetDevice(device_id));
   +  CUDA_CALL(cudaDeviceReset());
   +});
   +
    class GPUTimerNode : public TimerNode {
     public:
      virtual void Start() {
   ```
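   With the patch applied, the registered global can be looked up and called from Python between inferences (this mirrors how the reproduction script below uses it; it requires a CUDA-enabled TVM build that contains the patch):

   ```python
   import tvm

   # Look up the global registered by the patch and reset device 0.
   # After this call, all device memory and handles created earlier on
   # that GPU are invalid and must be recreated before the next run.
   reset_gpu = tvm.get_global_func("device_api.cuda_reset")
   reset_gpu(0)
   ```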
   
   ### Steps to reproduce
   
   ```Python
   import copy
   import tqdm
   import numpy as np
   from functools import partial
   from onnx import helper, checker, mapping

   import tvm
   from tvm import relay
   from tvm.relay import vm
   from tvm.contrib import graph_executor


   def get_two_input_model(op_name):
       in_shape = [1, 2, 3, 3]
       in_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype("float32")]
       out_shape = in_shape
       out_type = in_type

       layer = helper.make_node(op_name, ["in1", "in2"], ["out"])
       graph = helper.make_graph(
           [layer],
           "two_input_test",
           inputs=[
               helper.make_tensor_value_info("in1", in_type, in_shape),
               helper.make_tensor_value_info("in2", in_type, in_shape),
           ],
           outputs=[
               helper.make_tensor_value_info("out", out_type, out_shape),
           ],
       )
       model = helper.make_model(graph, producer_name="two_input_test")
       checker.check_model(model, full_check=True)
       return model


   def generate_input_dict():
       input_dict = {}
       input_info = [
           {"inputName": "in1", "inputDtype": "float32", "inputShape": [1, 2, 3, 3]},
           {"inputName": "in2", "inputDtype": "float32", "inputShape": [1, 2, 3, 3]},
       ]
       for i in input_info:
           input_name = i["inputName"]
           input_shape = i["inputShape"]
           input_dtype = i["inputDtype"]
           input_dict[input_name] = tvm.nd.array(
               np.random.uniform(size=input_shape).astype(input_dtype), tvm.cuda(0)
           )

       return input_dict


   def compile(model,
               executor,
               target,
               target_host,
               opt_level,
               opset,
               freeze_params,
               ):
       def get_tvm_executor(irmod, executor, target, target_host, params):
           if executor == "vm":
               lib = vm.compile(
                   copy.deepcopy(irmod),
                   target,
                   params=params,
                   target_host=target_host,
               )
           elif executor == "graph":
               lib = relay.build(irmod, target=target, target_host=target_host, params=params)
           else:
               print("ERROR: Executor type {} is unsupported. ".format(executor),
                     'Only "vm" and "graph" types are supported')
               return None
           return lib

       irmod, params = relay.frontend.from_onnx(model, opset=opset, freeze_params=freeze_params)

       with tvm.transform.PassContext(opt_level=opt_level):
           lib = get_tvm_executor(irmod, executor, target, target_host, params)

       # Build module
       ctx = tvm.device(target, 0)
       if executor == "vm":
           m = tvm.runtime.vm.VirtualMachine(lib, ctx)
       elif executor == "graph":
           m = graph_executor.GraphModule(lib["default"](ctx))
       mod = m.module

       # Set input
       tvm_inputs = generate_input_dict()
       if executor == "vm":
           set_input = mod.get_function("set_one_input")
           for inp_name, inp in tvm_inputs.items():
               set_input("main", inp_name, inp)
       elif executor == "graph":
           set_input = mod.get_function("set_input")
           for inp_name, inp in tvm_inputs.items():
               set_input(inp_name, inp)

       # Run
       if executor == "vm":
           run = partial(mod.get_function("invoke"), "main")
       else:
           run = mod.get_function("run")

       if executor == "vm":
           output = run()
           if isinstance(output, tvm.nd.NDArray):
               output = [output]
           return [tvm_nd_array.numpy() for tvm_nd_array in output]
       else:
           run()
           output = [
               mod.get_function("get_output")(output_index).numpy()
               for output_index in range(mod.get_function("get_num_outputs")())
           ]
           return output


   def main(executor):
       onnx_model = get_two_input_model("Add")

       compile_options = dict(
           target="cuda",
           target_host="llvm -mtriple=x86_64-linux-gnu",
           opt_level=3,
           opset_version=onnx_model.opset_import[0].version,
           freeze_weights=True,
       )

       compile(
           onnx_model,
           executor,
           compile_options["target"],
           compile_options["target_host"],
           compile_options["opt_level"],
           compile_options["opset_version"],
           compile_options["freeze_weights"],
       )


   if __name__ == "__main__":
       num_runs = 3
       reset_gpu = tvm.get_global_func("device_api.cuda_reset")

       print("\n\n********************* Compiling with Graph Executor *********************")
       for _ in tqdm.tqdm(range(num_runs)):
           main("graph")
           reset_gpu(0)

       print("\n\n********************* Compiling with Virtual Machine *********************")
       for _ in tqdm.tqdm(range(num_runs)):
           main("vm")
           reset_gpu(0)
   ```
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
