Karbit opened a new issue, #13803:
URL: https://github.com/apache/tvm/issues/13803

   The code generated by the CUDA backend for the following Relay function 
crashes due to an illegal memory access.
   
   ```
   def @main(%FunctionVar_01: Tensor[(1024, 256), float16], %FunctionVar_11: 
Tensor[(1024, 256), float16]) -> Tensor[(4096, 1, 64), float16] {
       %0 = reshape(%FunctionVar_01, newshape=[-1, 1, 256]);
       %1 = reshape(%FunctionVar_11, newshape=[1024, 1, 256]);
       %2 = add(%0, %1);
       %3 = split(%2, indices_or_sections=4, axis=2);
       %4 = %3.0;
       %5 = %3.1;
       %6 = %3.2;
       %7 = %3.3;
       %8 = expand_dims(%4, axis=0);
       %9 = expand_dims(%5, axis=0);
       %10 = expand_dims(%6, axis=0);
       %11 = expand_dims(%7, axis=0);
       %12 = (%8, %9, %10, %11);
       %13 = concatenate(%12);
       reshape(%13, newshape=[4096, 1, 64])
     } /* ty=fn (Tensor[(1024, 256), float16], Tensor[(1024, 256), float16]) -> 
Tensor[(4096, 1, 64), float16] */
   ```
   
   ### Expected behavior
   
   The script is expected to print the estimated time, in seconds, to run the Relay program.
   
   ### Actual behavior
   
   ```
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (e == cudaSuccess || e == cudaErrorCudartUnloading) is 
false: CUDA: an illegal memory access was encountered
   terminate called after throwing an instance of 'tvm::runtime::InternalError'
     what():  [17:31:24] 
/data/project/liujiaqiang/workdir/tvm/src/runtime/cuda/cuda_module.cc:61: 
CUDAError: cuModuleUnload(module_[i]) failed with error: 
CUDA_ERROR_ILLEGAL_ADDRESS
   Stack trace:
     0: tvm::runtime::CUDAModuleNode::~CUDAModuleNode()
     1: _ZN3tvm7runtime18SimpleObjAllocator7H
     2: 
tvm::runtime::SimpleObjAllocator::Handler<tvm::runtime::LibraryModuleNode>::Deleter_(tvm::runtime::Object*)
     3: _ZN3tvm7runtime18SimpleObjAllocator7HandlerINS0_16Pac
     4: tvm::runtime::vm::VirtualMachine::~VirtualMachine()
     5: _ZN3tvm7runtime18SimpleObjAllocator7H
     6: _ZN3tvm7runtime18SimpleObjAllocator7HandlerINS0_16Pac
     7: TVMObjectFree
     8: ffi_call_unix64
     9: ffi_call
     10: _call_function_pointer
           at 
/media/ssd1/fordata/web_server/project/liujiaqiang/Python-3.7.9/Modules/_ctypes/callproc.c:829
     11: _ctypes_callproc
           at 
/media/ssd1/fordata/web_server/project/liujiaqiang/Python-3.7.9/Modules/_ctypes/callproc.c:1201
     12: PyCFuncPtr_call
           at 
/media/ssd1/fordata/web_server/project/liujiaqiang/Python-3.7.9/Modules/_ctypes/_ctypes.c:4025
     13: _PyObject_FastCallKeywords
           at Objects/call.c:199
     14: call_function
           at Python/ceval.c:4619
     15: _PyEval_EvalFrameDefault
   ```
   
   ### Environment
   
   The operating system is CentOS 7, the CUDA version is 11.4, and the TVM commit id 
is 53824d697a.
   
   ### Steps to reproduce
   
   Save the Relay function into a text file named test.rly, and use the following 
script to reproduce the bug.
   
   ```python
   import tvm 
   from tvm.ir import IRModule, TypeCall
   from tvm.tir import Any
   
   import argparse
   import logging
   import os
   import math
   import tempfile
   
   import numpy as np
   import tvm
   import tvm.relay
   
   parser = argparse.ArgumentParser(description="Load from TVMScript module and 
compile")
   parser.add_argument("--tvm_script", type=str, default='test.rly', help="the 
tvm script to load")
   
   args = parser.parse_args()
   
   # Parameters to use when estimating latency (of both partitions and overall 
models).
   MEASURE_NUMBER = 20
   MEASURE_REPEAT = 5
   WARMUP_MIN_REPEAT_MS = 250
   
   def vm_estimate_seconds(device, the_vm, func_name, args):
       """Returns the estimated latency, in seconds, of running func_name with 
args on the_vm."""
       # Warmup
       the_vm.benchmark(
           device, repeat=1, number=1, min_repeat_ms=WARMUP_MIN_REPEAT_MS, 
func_name=func_name, **args
       )
       # One more time, with feeling
       return the_vm.benchmark(
           device,
           repeat=MEASURE_REPEAT,
           number=MEASURE_NUMBER,
           min_repeat_ms=0,
           func_name=func_name,
           **args,
       )
   
   def arg_for(arg_type, device):
       """Returns a test argument of Relay arg_type on device"""
       assert isinstance(arg_type, tvm.ir.TensorType)
       return tvm.nd.array(
           np.random.uniform(-1.0, 1.0, 
size=arg_type.concrete_shape).astype(arg_type.dtype),
           device=device,
       )
   
   def estimate_seconds(mod, target):
       """Returns the mean execution time of "main" in mod on target with 
params. The module
       may contain "Primitive" functions, possibly with "Compiler" 
attributes."""
       device = tvm.device(target.get_target_device_type())
       try:
           # Build the module.
           logging.info("Compiling module to estimate\n {}".format(mod))
           #print(mod.script())
           exe = tvm.relay.vm.compile(mod, target)
       except RuntimeError as err:
           # A build failure indicates the partition is not supported.
           # eg trying to build an nn.batch_norm on GPU, which has no schedule 
since we assume it
           # is only ever used with a tuple projection which is rewritten away.
           logging.info("Assigning module infinite cost since unable to build: 
%s", err)
           return math.inf
   
       # Finalize compilation
       tmp_dir = tempfile.mkdtemp()
       code, lib = exe.save()
       lib_path = os.path.join(tmp_dir, "library.so")
       # TODO(mbs): Avoid nvcc dependency?
       lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
       lib = tvm.runtime.load_module(lib_path)
       exe = tvm.runtime.vm.Executable.load_exec(code, lib)
   
       # Benchmark the module.
       the_vm = tvm.runtime.vm.VirtualMachine(exe, device)
       func_name = "main"
       main_args = {v.name_hint: arg_for(v.checked_type, device) for v in 
mod[func_name].params}
       logging.info("Benchmarking module to estimate")
       profile = vm_estimate_seconds(device, the_vm, func_name, main_args)
       logging.info("profile: %s", profile)
       return profile.median  # seconds
   
   mod = IRModule()
   mod._import(args.tvm_script)
   
   print("imported module:", mod)
   
   target = "cuda -libs=cudnn,cublas"
   target_host = None
   target = tvm.target.Target(target, target_host)
   #lib = tvm.relay.build(mod, target=target)
   #print(lib)
   print(estimate_seconds(mod, target))
   ```
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to