hxzd5568 commented on issue #6588:
URL: https://github.com/apache/tvm/issues/6588#issuecomment-1994549259
@tqchen @xiaoxia42, this seems to be a precision issue rather than a logic
error in the transformation.
When I used the same configuration but increased the precision (switching
the dtype to float64), the assertion error disappeared.
Judging from the CUDA code, a likely explanation is that when the reduction
is split into different groups whose partial sums are then accumulated,
truncation errors build up and affect the results.
```
import numpy as np
import tvm
from tvm import relay
import tvm.testing
from onnxutils import MSE
@tvm.testing.uses_gpu
def test_dense():
    """Check ``relay.nn.dense`` on CUDA against a NumPy reference in float64.

    Builds a dense op over an (8, 512) input and a (128, 512) weight,
    evaluates it with both the "graph" and the "debug" executor, prints
    ``MSE(result, reference)`` and asserts that the debug-executor result
    matches ``np.dot(x, w.T)`` with ``rtol=1e-13``. It then builds the
    function for CUDA, exports the library to ``compiled_model.tar`` and
    prints the generated device source.

    NOTE(review): the indentation of this snippet was lost in the e-mail and
    has been reconstructed. With a single dtype in the list, placing the
    build/export step inside or after the loop is equivalent.
    """
    # ------------ change to float64
    for dtype in ["float64"]:
        x = relay.var("x", shape=(8, 512), dtype=dtype)
        w = relay.var("w", shape=(128, 512), dtype=dtype)
        z = relay.nn.dense(x, w)

        # Check result.
        func = relay.Function([x, w], z)
        x_data = np.random.randn(8, 512).astype(dtype)
        w_data = np.random.randn(128, 512).astype(dtype)
        ref_res = np.dot(x_data, w_data.T)

        target = tvm.target.cuda()
        ctx = tvm.cuda(0)
        intrp1 = relay.create_executor("graph", device=ctx, target=target)
        intrp2 = relay.create_executor("debug", device=ctx, target=target)

        op_res1 = intrp1.evaluate(func)(x_data, w_data)
        # tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
        op_res2 = intrp2.evaluate(func)(x_data, w_data)
        # tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
        print(MSE(op_res2.asnumpy(), ref_res))

        # ---------- error disappears, and the mean relative error is around
        # 1e-15 -----
        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-13)

        # Build once more at opt_level=3 so the generated CUDA source can be
        # exported and inspected.
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build(func, target='cuda')
        lib.export_library("compiled_model.tar")
        print(lib.imported_modules[0].get_source())
# Run the reproduction directly when executed as a script.
if __name__ == "__main__":
    test_dense()
```
The CUDA code is as follows. The imprecision is probably caused by the
change in computation order. The new computation splits the original one
into 8 groups and then accumulates their results.
```
// Compatibility shim: when compiling device code for __CUDA_ARCH__ < 700,
// map the *_sync warp-shuffle intrinsics onto the legacy __shfl* forms.
// The mask argument is accepted but dropped by these macros.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
__shfl((var), (lane), (width))
#define __shfl_down_sync(mask, var, offset, width) \
__shfl_down((var), (offset), (width))
#define __shfl_up_sync(mask, var, offset, width) \
__shfl_up((var), (offset), (width))
#endif
// TVM_ENABLE_L2_PREFETCH is set to 1 when the compiler version is 11.4 or
// newer (major == 11 with minor >= 4, or major > 11), and to 0 otherwise.
#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
(__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
// Short integer type names used by the generated code: real type aliases on
// Windows (_WIN32), plain preprocessor macros everywhere else.
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
// Forward declaration of the generated dense kernel (at most 64 threads per
// block, per __launch_bounds__).
extern "C" __global__ void __launch_bounds__(64)
tvmgen_default_fused_nn_dense_kernel(float* __restrict__ T_matmul_NT, float*
__restrict__ p0, float* __restrict__ p1);
// Each block produces one output element:
//   T_matmul_NT[blockIdx.y * 128 + blockIdx.x]
// as the sum over a 512-element axis of
//   p0[blockIdx.y * 512 + k] * p1[blockIdx.x * 512 + k].
// Summation order: every thread first adds 8 products taken 64 elements
// apart, then the per-thread partial sums are combined by shuffle-based tree
// reductions. This differs from a sequential left-to-right dot product.
// NOTE(review): all buffers here are float, while the Python snippet above
// builds with dtype float64 -- presumably this dump comes from a float32
// build; confirm.
extern "C" __global__ void __launch_bounds__(64)
tvmgen_default_fused_nn_dense_kernel(float* __restrict__ T_matmul_NT, float*
__restrict__ p0, float* __restrict__ p1) {
// Per-thread partial sum and the shared slot for the block's final result.
float T_matmul_NT_rf[1];
__shared__ float red_result[1];
T_matmul_NT_rf[0] = 0.000000e+00f;
// Stage 1: thread t accumulates the products at k = k_outer * 64 + t for
// k_outer = 0..7.
for (int k_outer = 0; k_outer < 8; ++k_outer) {
T_matmul_NT_rf[0] = (T_matmul_NT_rf[0] + (p0[(((((int)blockIdx.y) * 512)
+ (k_outer * 64)) + ((int)threadIdx.x))] * p1[(((((int)blockIdx.x) * 512) +
(k_outer * 64)) + ((int)threadIdx.x))]));
}
// Scratch variables for the two reduction stages below.
float red_buf0[1];
uint mask[1];
float t0[1];
float red_buf0_1[1];
uint mask_1[1];
float t0_1[1];
// One shared staging slot per group of 32 threads.
__shared__ float red_buf_staging[2];
// Stage 2: tree reduction inside each 32-wide group using shuffle-down with
// offsets 16, 8, 4, 2, 1. Afterwards the first thread of each group holds
// that group's sum (assuming all 32 lanes participate).
red_buf0_1[0] = T_matmul_NT_rf[0];
mask_1[0] = __activemask();
t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 16, 32);
red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 8, 32);
red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 4, 32);
red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 2, 32);
red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
t0_1[0] = __shfl_down_sync(mask_1[0], red_buf0_1[0], 1, 32);
red_buf0_1[0] = (red_buf0_1[0] + t0_1[0]);
// The first thread of each 32-wide group publishes its sum to shared memory,
// indexed by group number (threadIdx.x >> 5).
if ((((int)threadIdx.x) % 32) == 0) {
red_buf_staging[(((int)threadIdx.x) >> 5)] = red_buf0_1[0];
}
__syncthreads();
// Stage 3: threads 0 and 1 load the two staged sums; a single shuffle-down
// by 1 lets thread 0 add thread 1's value. Only thread 0's result is used
// below.
if (((int)threadIdx.x) < 2) {
red_buf0[0] = red_buf_staging[((int)threadIdx.x)];
}
mask[0] = (__activemask() & (uint)3);
t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
red_buf0[0] = (red_buf0[0] + t0[0]);
// Thread 0 stores the block total in shared memory, then, after the barrier,
// writes it to the output element owned by this block.
if (((int)threadIdx.x) == 0) {
((volatile float*)red_result)[0] = red_buf0[0];
}
__syncthreads();
if (((int)threadIdx.x) == 0) {
T_matmul_NT[((((int)blockIdx.y) * 128) + ((int)blockIdx.x))] =
((volatile float*)red_result)[0];
}
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]