Jiawei-Shao commented on issue #16627:
URL: https://github.com/apache/tvm/issues/16627#issuecomment-2009056083
Hi @tqchen,
I'm stuck on translating `int8` to WGSL, so I'd like to ask you for help.
WGSL currently doesn't support 8-bit integer types, so in the generated WGSL we can only load and store `int8` data four elements at a time, packed into a `u32` (i.e. as `int8x4`). How can I make TVM generate TIR that always loads and stores four `int8` values together, instead of loading or storing each `int8` separately?
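To make the packed layout concrete, here is a minimal pure-Python sketch of how four `int8` values can share one 32-bit word. It assumes element 0 occupies the lowest byte (little-endian byte order); `pack_int8x4` and `unpack_int8x4` are hypothetical helpers for illustration, not TVM or WGSL APIs.

```python
import struct


def pack_int8x4(values):
    """Pack four signed 8-bit ints into one uint32, element 0 in the low byte (assumed layout)."""
    if len(values) != 4:
        raise ValueError("expected exactly four int8 values")
    word = 0
    for i, v in enumerate(values):
        if not -128 <= v <= 127:
            raise ValueError(f"{v} does not fit in int8")
        # Keep only the two's-complement byte and shift it into lane i.
        word |= (v & 0xFF) << (8 * i)
    return word


def unpack_int8x4(word):
    """Inverse of pack_int8x4: extract lane i and sign-extend it back to int8."""
    out = []
    for i in range(4):
        byte = (word >> (8 * i)) & 0xFF
        out.append(byte - 256 if byte >= 128 else byte)
    return out


vals = [1, -2, 3, -128]
word = pack_int8x4(vals)
print(hex(word))  # 0x8003fe01
# Same result as reinterpreting the four raw bytes as one little-endian uint32.
assert word == struct.unpack("<I", struct.pack("<4b", *vals))[0]
assert unpack_int8x4(word) == vals
```

Because the packing matches a plain little-endian reinterpretation of the bytes, a `u32` load of four consecutive `int8` elements needs no data shuffling, only a different element type and index.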
For example,
```Python
block_read_0 = sch.cache_read(block, 0, "shared")
```
generates the following TIR:
```python
X_shared_1 = T.Buffer((228,), "int8", data=X_shared, scope="shared")
for ax0, ax1 in T.grid(57, 4):
    X_1 = T.Buffer((16384,), "int8", data=X.data)
    X_shared_1[ax0 * 4 + ax1] = X_1[blockIdx_y * 8192 + i_2 * 128 + ax0 * 128 + k_0 * 4 + ax1]
```
which in turn generates the following WGSL (here `X` is an `int8` array, which is not valid WGSL):
```wgsl
for (var ax1 : i32 = 0; ax1 < 4i; ax1++) {
  X_shared[((ax0 * 4i) + ax1)] = X[(((((i32(blockIdx.y) * 8192i) + (i_2 * 128i)) + (ax0 * 128i)) + (k_0 * 4i)) + ax1)];
}
```
whereas I'd like to get the following WGSL (where `X_shared` and `X` are both `u32` arrays):
```wgsl
X_shared[ax0] = X[(((i32(blockIdx.y) * 2048i) + (i_2 * 32i)) + (ax0 * 32i)) + k_0];
```
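The word-level version is valid because every stride in the original byte index except that of `ax1` is a multiple of 4: dividing the address by 4 gives the `u32` word index, and `ax1` just selects the byte within that word. The pure-Python sketch below checks that copying 57 words moves exactly the same 228 bytes as the original `T.grid(57, 4)` byte loop. The helper names and the sample values of `blockIdx_y`, `i_2` and `k_0` are hypothetical, chosen only so every index stays inside the 16384-byte buffer.

```python
import random
import struct

X_BYTES = 16384     # size of the int8 buffer X in the TIR above
SHARED_BYTES = 228  # size of X_shared in bytes (57 u32 words)


def byte_index(block_idx_y, i_2, ax0, k_0, ax1):
    # Flat int8 index used by the original TIR loop.
    return block_idx_y * 8192 + i_2 * 128 + ax0 * 128 + k_0 * 4 + ax1


def word_index(block_idx_y, i_2, ax0, k_0):
    # Same address divided by 4: index into X viewed as u32 words.
    return block_idx_y * 2048 + i_2 * 32 + ax0 * 32 + k_0


def copy_bytewise(x_bytes, block_idx_y, i_2, k_0):
    # Mirrors the generated loop: 57 x 4 single-byte copies.
    shared = bytearray(SHARED_BYTES)
    for ax0 in range(57):
        for ax1 in range(4):
            shared[ax0 * 4 + ax1] = x_bytes[byte_index(block_idx_y, i_2, ax0, k_0, ax1)]
    return bytes(shared)


def copy_wordwise(x_words, block_idx_y, i_2, k_0):
    # Desired loop: 57 u32 copies, then view the result as bytes again.
    shared = [x_words[word_index(block_idx_y, i_2, ax0, k_0)] for ax0 in range(57)]
    return struct.pack("<57I", *shared)


def same_result(block_idx_y, i_2, k_0, seed=0):
    rng = random.Random(seed)
    x_bytes = bytes(rng.randrange(256) for _ in range(X_BYTES))
    # Reinterpret the same memory as 4096 little-endian u32 words.
    x_words = struct.unpack(f"<{X_BYTES // 4}I", x_bytes)
    return copy_bytewise(x_bytes, block_idx_y, i_2, k_0) == copy_wordwise(
        x_words, block_idx_y, i_2, k_0
    )


print(same_result(block_idx_y=1, i_2=2, k_0=3))  # True
```

The same argument only holds while the base offset stays 4-byte aligned; if any outer stride were not a multiple of 4, a single `u32` load could straddle two words and this rewrite would no longer be exact.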
Handling the declarations of `X_shared` and `X` is easy (for example, I can change `var<workgroup> X_shared : array<i8, 228>` to `var<workgroup> X_shared : array<u32, 57>`), but rewriting the loop that loads the `int8` data is difficult. Could you give me some advice on how to do it?
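The declaration change is mechanical: an `array<i8, N>` becomes `array<u32, N / 4>` whenever N is a multiple of 4 (228 bytes become 57 words). Below is a minimal text-level sketch of that rule; `pack_i8_array_decl` is a hypothetical helper, and this is not how TVM's WGSL codegen actually emits declarations.

```python
import re

# Matches the (invalid in WGSL) element type array<i8, N>.
_I8_ARRAY = re.compile(r"array<i8,\s*(\d+)>")


def pack_i8_array_decl(decl: str) -> str:
    """Rewrite every array<i8, N> in a declaration to array<u32, N / 4>."""

    def repl(match):
        n = int(match.group(1))
        if n % 4 != 0:
            raise ValueError(f"array<i8, {n}> cannot be packed evenly into u32 words")
        return f"array<u32, {n // 4}>"

    return _I8_ARRAY.sub(repl, decl)


print(pack_i8_array_decl("var<workgroup> X_shared : array<i8, 228>"))
# var<workgroup> X_shared : array<u32, 57>
```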
Here is the Python script I used for testing. I can't upload `.py` files, so I renamed it to `.txt`:
[tvm_dp4a.txt](https://github.com/apache/tvm/files/14663019/tvm_dp4a.txt)