roastduck commented on a change in pull request #5382:
URL: https://github.com/apache/incubator-tvm/pull/5382#discussion_r421205351
##########
File path: src/te/operation/op_util.cc
##########
@@ -164,9 +164,21 @@ MakeLoopNest(const Stage& stage,
value_map[iv] = dom->min;
} else {
runtime::ThreadScope ts =
runtime::ThreadScope::make(bind_iv->thread_tag);
- if (stage->scope == "" || stage->scope == "warp" ||
+ if (stage->scope == "" ||
static_cast<int>(runtime::StorageScope::make(stage->scope).rank)
<= ts.rank) {
value_map[iv] = var;
+ } else if (stage->scope == "warp" && ts.rank == 1) {
+ // To determine whether a thread index is inside or outside a warp, we need
+ // to know the thread extent. We leave a warning for now.
+ if (ts.dim_index == 0) {
+ value_map[iv] = var;
+ } else {
Review comment:
> However, it seems not the case, as your 3rd situation ends with incorrect code instead of an error from LowerWarpMemory(). But I don't know the reason.
Actually I mean:
1. The 1st situation has no problem, either before or after this PR.
2. The 2nd situation led to incorrect code before this PR, and correct code after this PR. Plus, we will see a warning after this PR.
3. The 3rd situation is currently not supported by `lower_warp_memory`, which will lead to an error. Plus, we will see a warning after this PR.

In any case, I still have no idea how the 2nd situation ends up with incorrect code.
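To make the three situations concrete, here is a small plain-Python model of the modified branch in `MakeLoopNest` (my own simplified sketch, not the actual C++; the storage-scope ranks are assumptions for illustration, and `thread_rank == 1` stands for the `threadIdx` scope):

```python
# Simplified model of the modified branch in MakeLoopNest
# (src/te/operation/op_util.cc). The scope->rank mapping below is an
# assumption for illustration, not copied verbatim from TVM.
STORAGE_RANK = {"global": 0, "shared": 1, "warp": 2, "local": 3}

def bind_decision(stage_scope, thread_rank, dim_index):
    """Return how a bound IterVar is handled, mirroring the diff."""
    if stage_scope == "" or STORAGE_RANK[stage_scope] <= thread_rank:
        return "use thread var"              # value_map[iv] = var
    elif stage_scope == "warp" and thread_rank == 1:
        # Whether the thread index stays inside one warp depends on the
        # thread extent, which is unknown here; the PR emits a warning.
        if dim_index == 0:
            return "use thread var (warn)"   # threadIdx.x
        else:
            return "use dom->min (warn)"     # threadIdx.y / threadIdx.z
    else:
        return "use dom->min"                # value_map[iv] = dom->min

# threadIdx has rank 1; dim_index 0 is threadIdx.x, 1 is threadIdx.y.
print(bind_decision("warp", 1, 0))
print(bind_decision("warp", 1, 1))
```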
> I see. Does your simplified case trigger the warning? If so, checking for the warning can guard your changes from being accidentally deleted or skipped.
Actually no. Let's have a look at some details.
Here's my simplified example:
```python
import tvm
from tvm import te

n = 32
A = te.placeholder((n, n), name='A', dtype="float32")
# Each output element reads its right neighbor in A, wrapping around.
C = te.compute((n, n), lambda i, j: A(i, (j + 1) % n), name='C')

s = te.create_schedule(C.op)
th_y = te.thread_axis("threadIdx.y")
th_x = te.thread_axis("threadIdx.x")
# Cache A in "warp" memory for computing C.
B = s.cache_read(A, "warp", [C])
ci, cj = C.op.axis
bi, bj = B.op.axis
s[C].bind(ci, th_y)
s[C].bind(cj, th_x)
s[B].compute_at(s[C], ci)
s[B].bind(bj, th_x)
print(tvm.lower(s, [A, C]))
```
And here's the result, which is unexpectedly correct before this PR.
```
PrimFunc([A, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
// attr [A.warp] storage_scope = "warp"
allocate A.warp[float32 * 32]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
A.warp[threadIdx.x] = A[((threadIdx.y*32) + threadIdx.x)]
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
C[((threadIdx.y*32) + threadIdx.x)] = A.warp[floormod((threadIdx.x + 1), 32)]
}
```
The `if (stage->scope == "warp" && ts.rank == 1)` branch in the modified
code is only triggered once, where `ts.dim_index == 0`. I don't know why the
`ts.dim_index == 1` `IterVar` is ignored.
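For reference, this is how I understand `runtime::ThreadScope::make` maps a thread tag to the `rank` and `dim_index` used above (a plain-Python sketch of my reading, not the actual C++; `vthread` and other special tags are left out):

```python
def thread_scope(tag):
    """Sketch of ThreadScope::make: map a thread tag to (rank, dim_index).
    blockIdx.* has rank 0, threadIdx.* has rank 1, and dim_index is
    0/1/2 for the .x/.y/.z suffix."""
    base, _, axis = tag.partition(".")
    rank = {"blockIdx": 0, "threadIdx": 1}[base]
    dim_index = {"x": 0, "y": 1, "z": 2}[axis]
    return rank, dim_index

print(thread_scope("threadIdx.x"))  # (1, 0)
print(thread_scope("threadIdx.y"))  # (1, 1)
```

So both `threadIdx` axes share `ts.rank == 1` and differ only in `ts.dim_index`, which is why the branch distinguishes them by `dim_index` alone.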
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]