yzh119 opened a new pull request, #15190: URL: https://github.com/apache/tvm/pull/15190
# Motivation

Currently, our CUDA codegen does not utilize CUDA's [half2](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____HALF2__ARITHMETIC.html#group__CUDA__MATH____HALF2__ARITHMETIC) and [nv_bfloat162](https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT162__ARITHMETIC.html#group__CUDA__MATH____BFLOAT162__ARITHMETIC) intrinsics; instead, it calls the scalar operators on each element of the vector, which is inefficient. This PR improves the generated CUDA code by emitting half2 and nv_bfloat162 intrinsics where possible, which can make the generated program run faster (in cases where nvcc does not perform this optimization itself).

The PR is based on #15183 and will be rebased onto mainline after that PR gets merged.

# Example

Suppose a user is vectorizing the following operation:

```python
import tvm
import tvm.tir as tir
from tvm.script import tir as T


@T.prim_func
def vec_fp16(a: T.Buffer((128,), "float16"), b: T.Buffer((128,), "float16")):
    for i in range(128):
        with T.block("b"):
            vi = T.axis.spatial(128, i)
            b[vi] = a[vi] * T.float16(3.0) + T.float16(1.0)


sch = tir.Schedule(vec_fp16)
b = sch.get_block("b")
i = sch.get_loops(b)[0]
bx, tx, vec = sch.split(i, [2, 32, 2])
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)
f = tvm.build(sch.mod["main"], target="cuda")
print(f.imported_modules[0].get_source())
```

Before this PR, TVM would emit the following CUDA code:

```cuda
extern "C" __global__ void __launch_bounds__(32) default_function_kernel(half* __restrict__ a, half* __restrict__ b) {
  uint1 __1;
  uint1 __2;
  uint1 v_ = *(uint1*)(a + ((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 2)));
  uint1 v__1 = make_uint1(__pack_half2(__float2half_rn(3.000000e+00f), __float2half_rn(3.000000e+00f)));
  ((half2*)(&(__2.x)))->x = (((half2*)(&(v_.x)))->x*((half2*)(&(v__1.x)))->x);
  ((half2*)(&(__2.x)))->y = (((half2*)(&(v_.x)))->y*((half2*)(&(v__1.x)))->y);
  uint1 v__2 = make_uint1(__pack_half2(__float2half_rn(1.000000e+00f), __float2half_rn(1.000000e+00f)));
  ((half2*)(&(__1.x)))->x = (((half2*)(&(__2.x)))->x+((half2*)(&(v__2.x)))->x);
  ((half2*)(&(__1.x)))->y = (((half2*)(&(__2.x)))->y+((half2*)(&(v__2.x)))->y);
  *(uint1*)(b + ((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 2))) = __1;
}
```

After this PR, TVM emits code that uses half2 intrinsics directly:

```cuda
extern "C" __global__ void __launch_bounds__(32) default_function_kernel(half* __restrict__ a, half* __restrict__ b) {
  *(half2*)(b + ((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 2))) = ((*(half2*)(a + ((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 2))) * make_half2(__float2half_rn(3.000000e+00f), __float2half_rn(3.000000e+00f))) + make_half2(__float2half_rn(1.000000e+00f), __float2half_rn(1.000000e+00f)));
}
```
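For context, this is roughly what a hand-written kernel using the CUDA half2 math API looks like. The sketch below is an illustration only, not code produced by this PR; the kernel name and launch configuration are assumptions. It computes the same `b[i] = a[i] * 3 + 1` over 128 half values, two lanes at a time, using the `__hfma2` fused multiply-add intrinsic:

```cuda
#include <cuda_fp16.h>

// Hypothetical hand-written equivalent: each thread processes one half2
// (a pair of adjacent half values) instead of two scalar halves.
__global__ void scale_add_half2(const half2* __restrict__ a,
                                half2* __restrict__ b) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  half2 scale = __float2half2_rn(3.0f);  // broadcast 3.0 into both lanes
  half2 bias  = __float2half2_rn(1.0f);  // broadcast 1.0 into both lanes
  // __hfma2 does a fused multiply-add on both lanes in one instruction.
  b[i] = __hfma2(a[i], scale, bias);
}

// Launch for 128 halves = 64 half2 elements, e.g.:
//   scale_add_half2<<<2, 32>>>(a, b);
```

Emitting the vector form directly, as this PR does, removes the per-lane `->x`/`->y` scalar arithmetic from the generated source rather than relying on nvcc to re-vectorize it.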
