coffezhou opened a new issue, #17963:
URL: https://github.com/apache/tvm/issues/17963
### Expected behavior
TVM should compile the model correctly.
### Actual behavior
When compiling the model with the CUDA backend, nvcc rejects the generated `tir_sinh_kernel` (`identifier "hsinh" is undefined`) and TVM crashes as follows:
```c
Traceback (most recent call last):
File
"/home/carla/Documents/test_tvm/test-tvm-llm/0425/bugs/onnx_output0/test.py",
line 44, in main
ex = tvm.compile(tvm_model, target="cuda")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/driver/build_module.py", line
104, in compile
return tvm.relax.build(
^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 259,
in build
return _vmlink(
^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/relax/vm_build.py", line 154,
in _vmlink
lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 186, in
build
return tir_to_runtime(host_mod, device_mod_dict, target_host)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 96, in
tir_to_runtime
device_modules.append(codegen_build(device_mod, target))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/carla/Documents/tvm/python/tvm/tir/build.py", line 80, in
codegen_build
return bf(mod, target)
^^^^^^^^^^^^^^^
File "tvm/ffi/cython/./function.pxi", line 212, in
tvm.ffi.core.Function.__call__
File "tvm/ffi/cython/./function.pxi", line 265, in
tvm.ffi.core.tvm_ffi_callback
File "/home/carla/Documents/tvm/python/tvm/contrib/nvcc.py", line 204, in
tvm_callback_cuda_compile
ptx = compile_cuda(code, target_format="fatbin")
File "/home/carla/Documents/tvm/python/tvm/contrib/nvcc.py", line 128, in
compile_cuda
raise RuntimeError(msg)
RuntimeError: #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#include <cuda_fp16.h>
__device__ half max(half a, half b)
{
return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
return __hlt(__half(a), __half(b)) ? a : b;
}
#else
typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;
#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP) \
TVM_XINLINE RTYPE operator OP (half a, half b) { \
return RTYPE(float(a) OP float(b)); \
} \
template<typename T> \
TVM_XINLINE RTYPE operator OP (half a, T b) { \
return RTYPE(float(a) OP float(b)); \
} \
template<typename T> \
TVM_XINLINE RTYPE operator OP (T a, half b) { \
return RTYPE(float(a) OP float(b)); \
}
#define TVM_HALF_ASSIGNOP(AOP, OP) \
template<typename T> \
TVM_XINLINE half operator AOP (const T& a) { \
return *this = half(float(*this) OP float(a)); \
} \
template<typename T> \
TVM_XINLINE half operator AOP (const volatile T& a) volatile { \
return *this = half(float(*this) OP float(a)); \
}
class TVM_ALIGNED(2) half {
public:
uint16_t half_;
static TVM_XINLINE half Binary(uint16_t value) {
half res;
res.half_ = value;
return res;
}
TVM_XINLINE half() {}
TVM_XINLINE half(const float& value) { constructor(value); }
TVM_XINLINE explicit half(const double& value) { constructor(value); }
TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
TVM_XINLINE operator float() const { \
return float(half2float(half_)); \
} \
TVM_XINLINE operator float() const volatile { \
return float(half2float(half_)); \
}
TVM_HALF_ASSIGNOP(+=, +)
TVM_HALF_ASSIGNOP(-=, -)
TVM_HALF_ASSIGNOP(*=, *)
TVM_HALF_ASSIGNOP(/=, /)
TVM_XINLINE half operator+() {
return *this;
}
TVM_XINLINE half operator-() {
return half(-float(*this));
}
TVM_XINLINE half operator=(const half& a) {
half_ = a.half_;
return a;
}
template<typename T>
TVM_XINLINE half operator=(const T& a) {
return *this = half(a);
}
TVM_XINLINE half operator=(const half& a) volatile {
half_ = a.half_;
return a;
}
template<typename T>
TVM_XINLINE half operator=(const T& a) volatile {
return *this = half(a);
}
private:
union Bits {
float f;
int32_t si;
uint32_t ui;
};
static int const fp16FractionBits = 10;
static int const fp32FractionBits = 23;
static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); //
== 0x7fffff
static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // ==
0x800000
static int const shift = fp32FractionBits - fp16FractionBits; // == 13
static int const shiftSign = 16;
static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so
exp16 = exp32 - (127-15)
static int32_t const infN = 0x7F800000; // flt32 infinity
static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16
normal after >> by shift
static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
static int32_t const maxZ = 0x33000000; // max fp32 number that's still
rounded to zero in fp16
static int32_t const signN = 0x80000000; // flt32 sign bit
static int32_t const infC = infN >> shift;
static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as
a flt32
static int32_t const maxC = maxN >> shift;
static int32_t const minC = minN >> shift;
static int32_t const signC = signN >> shiftSign; // flt16 sign bit
static int32_t const mulN = 0x52000000; // (1 << 23) / minN
static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
static int32_t const norC = 0x00400; // min flt32 normal down shifted
static int32_t const maxD = infC - maxC - 1;
static int32_t const minD = minC - subC - 1;
TVM_XINLINE uint16_t float2half(const float& value) const {
Bits v;
v.f = value;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position
if (v.si <= maxZ) {
// Handle eventual zeros here to ensure
// vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should
be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right
shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 :
0;
} else if (v.si <= maxN) {
// Handle norms
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}
v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}
// Same as above routine, except for addition of volatile keyword
TVM_XINLINE uint16_t float2half(
const volatile float& value) const volatile {
Bits v;
v.f = value;
uint32_t sign = v.si & signN; // grab sign bit
v.si ^= sign; // clear sign bit from v
sign >>= shiftSign; // logical shift sign to fp16 position
if (v.si <= maxZ) {
// Handle eventual zeros here to ensure
// vshift will not exceed 32 below.
v.ui = 0;
} else if (v.si < minN) {
// Handle denorms
uint32_t exp32 = v.ui >> fp32FractionBits;
int32_t exp16 = exp32 - expAdjust;
// If exp16 == 0 (just into the denorm range), then significant should
be shifted right 1.
// Smaller (so negative) exp16 values should result in greater right
shifts.
uint32_t vshift = 1 - exp16;
uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
v.ui = significand >> vshift;
v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 :
0;
} else if (v.si <= maxN) {
// Handle norms
v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
v.ui -= expAdjust << fp32FractionBits;
} else if (v.si <= infN) {
v.si = infN;
} else if (v.si < nanN) {
v.si = nanN;
}
v.ui >>= shift;
return sign | (v.ui & 0x7fff);
}
TVM_XINLINE float half2float(const uint16_t& value) const {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}
TVM_XINLINE float half2float(
const volatile uint16_t& value) const volatile {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}
template<typename T>
TVM_XINLINE void constructor(const T& value) {
half_ = float2half(float(value));
}
};
TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)
TVM_XINLINE half __float2half_rn(const float a) {
return half(a);
}
#endif
#include <cuda.h>
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) { \
float tmp_x = __half2float(x); \
float tmp_y = __half2float(y); \
float result = FP32_MATH_NAME(tmp_x, tmp_y); \
return __float2half(result); \
}
#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x) { \
float tmp_x = __half2float(x); \
float result = FP32_MATH_NAME(tmp_x); \
return __float2half(result); \
}
// Some fp16 math functions are not supported in cuda_fp16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 530)
CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
#if ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) &&
(__CUDACC_VER_MINOR__ < 8)))
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
#endif
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
#else
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hexp, exp)
#endif
#endif
#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
#include <type_traits>
template <typename T, typename TVec2>
struct __align__(8) half4_bfloat164 {
T x, y, z, w;
__host__ __device__ half4_bfloat164() : x(T(0)), y(T(0)), z(T(0)), w(T(0))
{}
__host__ __device__ half4_bfloat164(T x, T y, T z, T w) : x(x), y(y),
z(z), w(w) {}
};
using half4 = half4_bfloat164<__half, __half2>;
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w)
{
return half4(x, y, z, w);
}
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
(__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(4) tir_sinh_kernel(half*
__restrict__ compute, half* __restrict__ lv3);
extern "C" __global__ void __launch_bounds__(4) tir_sinh_kernel(half*
__restrict__ compute, half* __restrict__ lv3) {
compute[((int)threadIdx.x)] = hsinh(lv3[((int)threadIdx.x)]);
}
Compilation error:
/tmp/tmpzws9vm8t/tvm_kernels.cu(354): error: identifier "hsinh" is undefined
/tmp/tmpzws9vm8t/tvm_kernels.cu(277): warning #177-D: function
"__pack_half2" was declared but never referenced
/tmp/tmpzws9vm8t/tvm_kernels.cu(303): warning #177-D: function "hpow" was
declared but never referenced
/tmp/tmpzws9vm8t/tvm_kernels.cu(305): warning #177-D: function "htanh" was
declared but never referenced
/tmp/tmpzws9vm8t/tvm_kernels.cu(307): warning #177-D: function "htan" was
declared but never referenced
/tmp/tmpzws9vm8t/tvm_kernels.cu(308): warning #177-D: function "hatan" was
declared but never referenced
/tmp/tmpzws9vm8t/tvm_kernels.cu(309): warning #177-D: function "herf" was
declared but never referenced
1 error detected in the compilation of "/tmp/tmpzws9vm8t/tvm_kernels.cu".
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File
"/home/carla/Documents/test_tvm/test-tvm-llm/0425/bugs/onnx_output0/test.py",
line 48, in <module>
main()
^^^^^^
File
"/home/carla/Documents/test_tvm/test-tvm-llm/0425/bugs/onnx_output0/test.py",
line 40, in main
with tvm.target.Target("cuda"):
File "/home/carla/Documents/tvm/python/tvm/target/target.py", line 145, in
__exit__
_ffi_api.TargetExitScope(self)
File "tvm/ffi/cython/./function.pxi", line 212, in
tvm.ffi.core.Function.__call__
tvm.error.InternalError: Check failed:
(entry->context_stack.top().same_as(*this)) is false:
```
### Environment
OS: Ubuntu 20.04
TVM: 0.21.dev0(bcb68b130)
CUDA: 11.8
### Steps to reproduce
This bug can be reproduced with the following script and the model in the
attachment. As the script shows, onnxruntime executes the model successfully,
but TVM fails to compile it with the CUDA backend.
```python
import sys
import numpy as np
import onnx
import onnxruntime
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import pickle
def main():
    """Reproduce the CUDA build failure for the attached ONNX model.

    Loads ``a719.onnx`` and the pickled inputs, runs the model once with
    onnxruntime on the CPU provider and prints the result, then imports the
    model into Relax, legalizes it, applies the default GPU schedule and
    compiles it for the ``cuda`` target.

    Exits with status 1 if the onnxruntime run raises.
    """
    onnx_model = onnx.load("a719.onnx")

    # NOTE(review): unpickling can execute arbitrary code. Only run this with
    # the inputs.pkl shipped in the issue's testcase attachment.
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    # Reference run: shows the model itself is valid before handing it to TVM.
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(),
            providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:
        print(e)
        sys.exit(1)

    print(ort_output)

    # Convert the onnx model into relax through the onnx importer.
    tvm_model = from_onnx(onnx_model, keep_params_in_input=True)
    # Convert operators for inference mode.
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    # Legalize any relax ops into tensorir.
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    # Separate model from parameters.
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    #----------------------cuda-----------------------
    # The compile call sits inside the Target scope: the reported traceback
    # shows Target.__exit__ running while the compile error is being handled.
    with tvm.target.Target("cuda"):
        tvm_model = tvm.tir.transform.DefaultGPUSchedule()(tvm_model)
        with tvm.transform.PassContext(opt_level=3):
            ex = tvm.compile(tvm_model, target="cuda")
            # NOTE(review): original indentation was lost; this line is
            # assumed to sit in the same scope as the compile call.
            vm1 = relax.VirtualMachine(ex, tvm.cuda())


if __name__ == "__main__":
    main()
```
[testcase.zip](https://github.com/user-attachments/files/20179895/testcase.zip)
### Triage
* needs-triage
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org