roastduck commented on a change in pull request #5498:
URL: https://github.com/apache/incubator-tvm/pull/5498#discussion_r421221319
##########
File path: src/tir/transforms/lower_warp_memory.cc
##########
@@ -265,10 +265,11 @@ class WarpAccessRewriter : protected StmtExprMutator {
<< op->index << " local_index=" << local_index;
PrimExpr load_value = LoadNode::make(
op->dtype, op->buffer_var, local_index, op->predicate);
+ PrimExpr mask = IntImm(DataType::UInt(32), 0xFFFFFFFF);
Review comment:
Please have a look on this example (modified from a test for
`lower_warp_memory`):
```python
import tvm
import topi
import numpy as np
from tvm import te
A = te.placeholder((128,), name='A', dtype="float32")
B = te.compute((100,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name='B')
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
with cuda_target:
s = te.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], 64)
xi0, xi1 = s[B].split(xi, factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
print(tvm.build(s, [A, B], "cuda").imported_modules[0].get_source())
```
The generated code is:
```cuda
extern "C" __global__ void default_function_kernel0(void* __restrict__ A,
void* __restrict__ B) {
float A_warp[2];
for (int ax0_outer = 0; ax0_outer < 2; ++ax0_outer) {
A_warp[(ax0_outer)] = ((float*)A)[((((((int)blockIdx.x) * 64) +
(ax0_outer * 32)) + ((int)threadIdx.x)))];
}
for (int i_inner_outer = 0; i_inner_outer < 2; ++i_inner_outer) {
if ((((((int)blockIdx.x) * 64) + (i_inner_outer * 32)) +
((int)threadIdx.x)) < 100) {
((float*)B)[((((((int)blockIdx.x) * 64) + (i_inner_outer * 32)) +
((int)threadIdx.x)))] = __shfl(A_warp[(i_inner_outer)], ((((int)threadIdx.x) +
1) & 31), 32);
}
}
}
```
Here `__shfl` is inside an `if`. Please check whether `__shfl_sync`
can deal with this example.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]