Kathryn-cat opened a new pull request, #18027:
URL: https://github.com/apache/tvm/pull/18027
This PR adds support for the FP4/FP8 data types available on the Blackwell
architecture (sm_100).
TVM's NDArray stores sub-byte data types in a compact (packed) format, so two
FP4 values occupy a single byte. The size calculator used by the array
allocator is modified accordingly.
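A minimal sketch of the packed size calculation, assuming the byte count is simply the total number of bits rounded up to a whole byte. `PackedNbytes` is a hypothetical helper for illustration, not the function modified in this PR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not TVM's actual allocator code): bytes needed to hold
// `num_elements` values of `bits` bits and `lanes` lanes each when sub-byte
// values are packed densely, e.g. two 4-bit FP4 values per byte.
std::int64_t PackedNbytes(std::int64_t num_elements, int bits, int lanes = 1) {
  assert(num_elements >= 0 && bits > 0 && lanes > 0);
  // Total bit count, rounded up to a whole byte.
  return (num_elements * bits * lanes + 7) / 8;
}
```

With this rule, 16 FP4 values take 8 bytes, and an odd count such as 5 FP4 values rounds up to 3 bytes, whereas a per-element byte rounding would spend a full byte on each value.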
---
**Subtype arithmetic**
The type `__nv_fp4_e2m1` from `<cuda_fp4.h>` is a tag type and does not
support pointer arithmetic over 4-bit elements. Accordingly, indexing into an
array declared directly as `__nv_fp4_e2m1` does not address individual FP4
elements correctly. If index operations such as `arr[0] + arr[1]` are needed,
the user should declare the array with a vector type such as `__nv_fp4x2_e2m1`.
For example, suppose the user creates a TVM array of type `__nv_fp4_e2m1` with
values `[-1 2 0.5 -6 -6 -2 2 3 4 1 -3 4 -2 2...]` and runs the following
kernel:
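```
#include <cuda_fp16.h>  // half
#include <cuda_fp4.h>   // __nv_fp4_e2m1
// Element-wise C = A + B computed in half precision; note that it indexes
// __nv_fp4_e2m1 pointers directly.
extern "C" __global__ void __launch_bounds__(32) add_kernel(
    __nv_fp4_e2m1* __restrict__ A, __nv_fp4_e2m1* __restrict__ B,
    __nv_fp4_e2m1* __restrict__ C) {
  C[((int)threadIdx.x)] = (__nv_fp4_e2m1)(((half)A[((int)threadIdx.x)]) +
                                          ((half)B[((int)threadIdx.x)]));
}
```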
Printing the values of A[0], A[1], ... shows:
```
A[0]: 2.000000
A[1]: -6.000000
A[2]: -2.000000
A[3]: 3.000000
```
This is because `__nv_fp4_e2m1` is only a tag type. Indexing the pointer
advances it one byte at a time, and each access yields the upper 4 bits of the
corresponding byte in the packed buffer, so A[i] returns the element at index
2i+1 (2, -6, -2, 3 above). As a result, we should avoid indexing
`__nv_fp4_e2m1` directly for arithmetic operations.
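To make the byte-stride behavior concrete, here is a host-side C++ model of the packed buffer from the example above. It is a sketch under stated assumptions, not CUDA's or TVM's implementation: the nibble order (element 2k in the low nibble of byte k, element 2k+1 in the high nibble) is assumed, and `LoadByteStrided` models the behavior described above rather than calling `<cuda_fp4.h>`. All function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// E2M1 magnitudes indexed by the low 3 bits of a 4-bit code.
constexpr float kE2M1Magnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                     2.0f, 3.0f, 4.0f, 6.0f};

// Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 select the magnitude.
float DecodeE2M1(std::uint8_t code) {
  float mag = kE2M1Magnitude[code & 0x7];
  return (code & 0x8) ? -mag : mag;
}

// Encode a value that is exactly representable in E2M1 (no rounding here).
std::uint8_t EncodeE2M1(float v) {
  std::uint8_t sign = std::signbit(v) ? 0x8 : 0x0;
  float mag = std::fabs(v);
  for (std::uint8_t i = 0; i < 8; ++i) {
    if (kE2M1Magnitude[i] == mag) return static_cast<std::uint8_t>(sign | i);
  }
  assert(false && "value not exactly representable in E2M1");
  return 0;
}

// Pack two values per byte (assumed order: even index in the low nibble).
std::vector<std::uint8_t> PackFp4(const std::vector<float>& values) {
  std::vector<std::uint8_t> buf((values.size() + 1) / 2, 0);
  for (std::size_t i = 0; i < values.size(); ++i) {
    std::uint8_t code = EncodeE2M1(values[i]);
    buf[i / 2] |= (i % 2 == 0) ? code : static_cast<std::uint8_t>(code << 4);
  }
  return buf;
}

// Correct element access: pick byte i / 2, then select the nibble.
float LoadElement(const std::vector<std::uint8_t>& buf, std::size_t i) {
  std::uint8_t byte = buf[i / 2];
  return DecodeE2M1(static_cast<std::uint8_t>(i % 2 == 0 ? (byte & 0xF)
                                                         : (byte >> 4)));
}

// Models the observed indexing of an __nv_fp4_e2m1 pointer: index i steps
// one whole byte and yields the upper 4 bits of byte i.
float LoadByteStrided(const std::vector<std::uint8_t>& buf, std::size_t i) {
  return DecodeE2M1(static_cast<std::uint8_t>(buf[i] >> 4));
}
```

Packing the first eight example values and calling `LoadByteStrided` for i = 0..3 returns 2, -6, -2, 3, matching the printed output, while `LoadElement` recovers the original values.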
If the user passes in an `__nv_fp4_e2m1` NDArray and performs indexing, we
could convert it to `__nv_fp4x2_e2m1` and recompute the indices where
possible, but this requires more careful handling during lowering.
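The index recalculation mentioned above could look roughly like the following sketch. `RemapFp4Index` and `Fp4x2Index` are hypothetical names, and the assumption that lane 0 of an `__nv_fp4x2_e2m1` value occupies the low nibble is made for illustration only; it is not taken from the PR or the CUDA headers.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical result of rewriting an index on a logical __nv_fp4_e2m1 array
// into an index on an __nv_fp4x2_e2m1 array (a sketch, not TVM's lowering).
struct Fp4x2Index {
  std::int64_t vector_index;  // which __nv_fp4x2_e2m1 element (byte) to load
  int lane;                   // assumed: 0 -> low nibble, 1 -> high nibble
};

// Element i of the FP4 array lives in vector element i / 2, lane i % 2.
Fp4x2Index RemapFp4Index(std::int64_t element_index) {
  assert(element_index >= 0);
  return {element_index / 2, static_cast<int>(element_index % 2)};
}
```

A lowering pass using such a mapping would load the `__nv_fp4x2_e2m1` element at `vector_index` and extract the nibble selected by `lane` before converting to a wider type.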
Therefore, the corresponding original test case in
`test_target_codegen_cuda_fp4.py` has been removed.
--
This is an automated message from the Apache Git Service.