roastduck commented on a change in pull request #5382:
URL: https://github.com/apache/incubator-tvm/pull/5382#discussion_r420187744



##########
File path: src/te/operation/op_util.cc
##########
@@ -164,9 +164,21 @@ MakeLoopNest(const Stage& stage,
         value_map[iv] = dom->min;
       } else {
         runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
-        if (stage->scope == "" || stage->scope == "warp" ||
+        if (stage->scope == "" ||
             static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
           value_map[iv] = var;
+        } else if (stage->scope == "warp" && ts.rank == 1) {
+          // To determine whether a thread index is inside or outside a warp, we need
+          // to know the thread extent. We leave a warning for now.
+          if (ts.dim_index == 0) {
+            value_map[iv] = var;
+          } else {

Review comment:
Let me explain this piece of code:

Line 167: If we are using local storage, then every iteration variable `iv` contributes to the access offset. For example, when accessing a local buffer `a[i, j]`, both `i` and `j` are part of the offset into `a`.
   
Line 168: There are two cases:

1. If we are using global storage, then all indices contribute to the offset. For example, `b[blockIdx.x, threadIdx.x, i]`.
2. If we are using shared storage, then only the **variables inside a block** contribute to the offset. These are the iteration variables and the thread indices (storage rank `<= ts.rank`). Block indices (storage rank `> ts.rank`) do not contribute. For example, a shared buffer `c` should be accessed as `c[threadIdx.x, i]`, not `c[blockIdx.x, threadIdx.x, i]`.
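A minimal C++ sketch of the rank comparison on line 168 follows. It assumes the ordering of TVM's `runtime::StorageRank` (global 0, shared 1, warp 2, local 3) and of `runtime::ThreadScope::rank` (0 for `blockIdx.*`, 1 for `threadIdx.*`). The names `ThreadIndexInOffset`, `kBlockRank` and `kThreadRank` are hypothetical and are not TVM API.

```cpp
#include <cassert>

// Assumed to mirror the ordering of TVM's runtime::StorageRank.
enum class StorageRank { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };

// Hypothetical constants for the assumed runtime::ThreadScope::rank values.
constexpr int kBlockRank = 0;   // blockIdx.*
constexpr int kThreadRank = 1;  // threadIdx.*

// Hypothetical helper: returns true if a thread index of rank `thread_rank`
// appears in the offset of a buffer with the given storage rank, following
// the `StorageScope::make(stage->scope).rank <= ts.rank` test on line 168.
bool ThreadIndexInOffset(StorageRank storage, int thread_rank) {
  assert(thread_rank == kBlockRank || thread_rank == kThreadRank);
  return static_cast<int>(storage) <= thread_rank;
}
```

For shared storage, `ThreadIndexInOffset(StorageRank::kShared, kBlockRank)` is false and `ThreadIndexInOffset(StorageRank::kShared, kThreadRank)` is true, which matches `c[threadIdx.x, i]`. With `StorageRank::kWarp`, both calls return false. That is why the removed `stage->scope == "warp"` special case had to be replaced by a dedicated warp branch rather than simply dropped.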
   
Now we come to warp storage. Ideally, all variables **inside a warp** should contribute to the offset, and those **outside a warp** should not. Determining whether a variable is inside a warp is complex, but usually only `threadIdx.x` is, so we make that assumption and emit a warning otherwise.
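The following hedged sketch shows why the thread extent matters. It assumes CUDA's thread linearization, where the linear ID is `threadIdx.x + threadIdx.y * blockDim.x` and each warp holds `warp_size` consecutive IDs. It considers only a 2-D block.

- `WarpRuleKeepsIndex` models the heuristic in this diff.
- `ThreadIdxYVariesWithinWarp` is a hypothetical helper showing what knowing the extent would let us decide.

Neither function is TVM API.

```cpp
#include <cassert>

// Hypothetical model of the heuristic in this diff for "warp" storage:
// only threadIdx.x (thread rank 1, dim_index 0) is treated as inside a warp.
bool WarpRuleKeepsIndex(int thread_rank, int dim_index) {
  assert(thread_rank == 0 || thread_rank == 1);  // 0: blockIdx.*, 1: threadIdx.*
  assert(dim_index >= 0 && dim_index <= 2);      // 0/1/2: .x/.y/.z
  return thread_rank == 1 && dim_index == 0;
}

// Hypothetical helper for a 2-D thread block: does threadIdx.y take more than
// one value inside some warp? If block_dim_x is a multiple of warp_size, every
// change of threadIdx.y falls on a warp boundary. Otherwise the linear IDs
// block_dim_x - 1 (threadIdx.y == 0) and block_dim_x (threadIdx.y == 1) share
// a warp.
bool ThreadIdxYVariesWithinWarp(int block_dim_x, int block_dim_y, int warp_size) {
  assert(block_dim_x > 0 && block_dim_y > 0 && warp_size > 0);
  if (block_dim_y == 1) return false;  // threadIdx.y is always 0
  return block_dim_x % warp_size != 0;
}
```

Consider `blockDim.x == 16` with 32-thread warps. Here `ThreadIdxYVariesWithinWarp(16, 4, 32)` is true, because each warp covers two rows, so `threadIdx.y` really is inside the warp even though the heuristic excludes it. With `blockDim.x == 64`, `threadIdx.y` is safely outside. Without the extent, the code cannot tell these cases apart, which is why it only warns.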




----------------------------------------------------------------
This is an automated message from the Apache Git Service.