roastduck commented on a change in pull request #5382:
URL: https://github.com/apache/incubator-tvm/pull/5382#discussion_r421205351



##########
File path: src/te/operation/op_util.cc
##########
@@ -164,9 +164,21 @@ MakeLoopNest(const Stage& stage,
         value_map[iv] = dom->min;
       } else {
        runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
-        if (stage->scope == "" || stage->scope == "warp" ||
+        if (stage->scope == "" ||
            static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
           value_map[iv] = var;
+        } else if (stage->scope == "warp" && ts.rank == 1) {
+          // To determine whether a thread index is inside or outside a warp, we need
+          // to know the thread extent. We leave a warning for now.
+          if (ts.dim_index == 0) {
+            value_map[iv] = var;
+          } else {

Review comment:
   > However, it seems not the case, as your 3rd situation ends with incorrect code instead of an error from LowerWarpMemory(). But I don't know the reason.
   
   Actually I mean:
   
   1. The 1st situation has no problem, before and after this PR.
   2. The 2nd situation led to incorrect code before this PR, and correct code after this PR. Plus, we will see a warning after this PR.
   3. The 3rd situation is currently not supported by `lower_warp_memory`, which will lead to an error. Plus, we will see a warning after this PR.
   
   In any case, I still have no idea how the 2nd situation ends up with incorrect code.
   
   > I see. Does your simplified case trigger the warning? If so, checking for the warning can guard your changes from being accidentally deleted or skipped.
   
   Actually no. Let's have a look at some details.
   
   Here's my simplified example:
   
   ```python
   import tvm
   
   from tvm import te
   
   n = 32
   A = te.placeholder((n, n), name='A', dtype="float32")
   C = te.compute((n, n), lambda i, j: A(i, (j + 1) % n), name='C')
   
   s = te.create_schedule(C.op)
   th_y = te.thread_axis("threadIdx.y")
   th_x = te.thread_axis("threadIdx.x")
   B = s.cache_read(A, "warp", [C])  # cache A in the "warp" storage scope
   ci, cj = C.op.axis
   bi, bj = B.op.axis
   s[C].bind(ci, th_y)
   s[C].bind(cj, th_x)
   s[B].compute_at(s[C], ci)
   s[B].bind(bj, th_x)  # only threadIdx.x is bound on the warp cache stage
   
   print(tvm.lower(s, [A, C]))
   ```
   
   And here's the result, which is unexpectedly correct before this PR.
   
   ```
   PrimFunc([A, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
     // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32
     // attr [A.warp] storage_scope = "warp"
     allocate A.warp[float32 * 32]
     // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
     A.warp[threadIdx.x] = A[((threadIdx.y*32) + threadIdx.x)]
     // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
     C[((threadIdx.y*32) + threadIdx.x)] = A.warp[floormod((threadIdx.x + 1), 32)]
   }
   ```
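   For reference, here is what "correct" means for this schedule, as a standalone NumPy sketch (independent of TVM; `n`, `A`, `C` mirror the names in the repro above):

   ```python
   import numpy as np

   n = 32
   A = np.random.rand(n, n).astype("float32")

   # C[i, j] = A[i, (j + 1) % n]: each row of A rotated left by one element,
   # which is what the lowered code above computes per (threadIdx.y, threadIdx.x).
   C = A[:, (np.arange(n) + 1) % n]

   # Sanity check: the rotation is equivalent to np.roll by -1 along axis 1.
   assert np.allclose(C, np.roll(A, -1, axis=1))
   ```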
   
   The `if (stage->scope == "warp" && ts.rank == 1)` branch in the modified code is only triggered once, with `ts.dim_index == 0`. I don't know why the `IterVar` with `ts.dim_index == 1` is ignored.
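   As a side note for readers following along, the `rank`/`dim_index` fields can be illustrated with a small Python sketch (`parse_thread_tag` is a hypothetical helper mimicking what `runtime::ThreadScope::make` does in C++ for the common GPU tags):

   ```python
   def parse_thread_tag(tag):
       # Hypothetical mimic of runtime::ThreadScope::make:
       # rank 0 for blockIdx.*, rank 1 for threadIdx.*,
       # dim_index 0/1/2 for the x/y/z dimension.
       prefix, _, dim = tag.partition(".")
       rank = {"blockIdx": 0, "threadIdx": 1}[prefix]
       dim_index = {"x": 0, "y": 1, "z": 2}[dim]
       return rank, dim_index

   # Both bound axes in the repro satisfy rank == 1, but only
   # threadIdx.x has dim_index == 0.
   print(parse_thread_tag("threadIdx.x"))  # (1, 0)
   print(parse_thread_tag("threadIdx.y"))  # (1, 1)
   ```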




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

