yongfeng-nv commented on a change in pull request #5382:
URL: https://github.com/apache/incubator-tvm/pull/5382#discussion_r421228975



##########
File path: src/te/operation/op_util.cc
##########
@@ -164,9 +164,21 @@ MakeLoopNest(const Stage& stage,
         value_map[iv] = dom->min;
       } else {
         runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
-        if (stage->scope == "" || stage->scope == "warp" ||
+        if (stage->scope == "" ||
             static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
           value_map[iv] = var;
+        } else if (stage->scope == "warp" && ts.rank == 1) {
+          // To determine whether a thread index is inside or outside a warp, we need
+          // to know the thread extent. We leave a warning for now.
+          if (ts.dim_index == 0) {
+            value_map[iv] = var;
+          } else {

Review comment:
       I modified the simplified example slightly to bind threadIdx.y in the warp stage, so that threadIdx.y goes through the new code path.
   
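       import tvm
       from tvm import te
       
       n = 32
       A = te.placeholder((2, n, n), name='A', dtype="float32")
       # Each element of C reads its right neighbour (wrapping around) along the last axis.
       C = te.compute((2, n, n), lambda x, i, j: A(x, i, (j + 1) % n), name='C')
       
       s = te.create_schedule(C.op)
       bk_x = te.thread_axis("blockIdx.x")
       th_y = te.thread_axis("threadIdx.y")
       th_x = te.thread_axis("threadIdx.x")
       # Stage A through warp-scoped memory before C consumes it.
       B = s.cache_read(A, "warp", [C])
       cx, ci, cj = C.op.axis
       bx, bi, bj = B.op.axis
       # threadIdx.y is intentionally not bound in C, only in the warp stage B.
       # s[C].bind(ci, th_y)
       s[C].bind(cj, th_x)
       s[C].bind(cx, bk_x)
       # Compute B inside C's blockIdx.x loop.
       s[B].compute_at(s[C], cx)
       # Binding threadIdx.y in B sends it through the new warp branch in op_util.cc.
       s[B].bind(bi, th_y)
       s[B].bind(bj, th_x)
       
       print(tvm.lower(s, [A, C]))
       func = tvm.build(s, [A, C], target="cuda", name='tid')
       # Print the generated CUDA kernel source.
       print(func.imported_modules[0].get_source())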
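The visible branches of the patched `if`/`else if` chain can be modeled in plain Python to check which one each binding in the example above reaches. This is a hypothetical sketch, not TVM code. `hunk_branch` and `thread_scope` are made-up names. The storage ranks (global=0, shared=1, warp=2, local=3) are assumed to mirror TVM's `StorageRank` enum. The bodies of the new `else` branch and of the remaining branch are cut off in the hunk, so the model only labels them.

```python
# Hypothetical model of the branch selection in the op_util.cc hunk.
# ASSUMPTION: storage-scope ranks mirror TVM's StorageRank enum
# (global=0, shared=1, warp=2, local=3); thread ranks follow ThreadScope
# (blockIdx.* -> 0, threadIdx.* -> 1, dim_index 0/1/2 for x/y/z).
STORAGE_RANK = {"global": 0, "shared": 1, "warp": 2, "local": 3}


def thread_scope(tag):
    """Return (rank, dim_index) for a thread tag such as 'threadIdx.y'."""
    prefix, dim = tag.split(".")
    rank = {"blockIdx": 0, "threadIdx": 1}[prefix]
    return rank, "xyz".index(dim)


def hunk_branch(stage_scope, thread_tag):
    """Name the branch of the patched if/else-if chain that a binding reaches."""
    rank, dim_index = thread_scope(thread_tag)
    if stage_scope == "" or STORAGE_RANK[stage_scope] <= rank:
        return "var"  # value_map[iv] = var
    if stage_scope == "warp" and rank == 1:
        if dim_index == 0:
            return "var"  # threadIdx.x in a warp stage keeps its variable
        return "warp-else"  # new else branch; its body is cut off in the hunk
    return "not-shown"  # the remaining branch lies outside the hunk


# Bindings from the example: C keeps the default scope "", B is the "warp" cache stage.
for stage, scope, tag in [("C", "", "blockIdx.x"), ("C", "", "threadIdx.x"),
                          ("B", "warp", "threadIdx.y"), ("B", "warp", "threadIdx.x")]:
    print(stage, tag, hunk_branch(scope, tag))
```

Under these assumptions, the threadIdx.y binding on the warp stage `B` is the only binding in the example that lands in the new `else` branch.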
   
   The three situations make a good summary.  
   The first one already has at least one test in tests/python/unittest/test_tir_transform_lower_warp_memory.py.
   I hope the code above can lock down the second situation and, with a reduced threadIdx.x extent, probably the error for the third one as well.
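The diff's comment notes that deciding whether a thread index lies inside or outside a warp requires the thread extent. A small sketch can illustrate why. It assumes CUDA's usual linearization of a 2-D block (linear id = x + y * extent_x) and 32-thread warps. The helper names are hypothetical. With the example's threadIdx.x extent of 32, each warp holds a single threadIdx.y value. Reducing threadIdx.x's extent to 16 packs two threadIdx.y values into every warp.

```python
# Which threadIdx.y values share a warp? Sketch assuming CUDA's usual
# linearization (linear id = x + y * extent_x) and a warp size of 32.
WARP_SIZE = 32


def warps_by_thread_y(extent_x, extent_y, warp_size=WARP_SIZE):
    """Map each warp id to the set of threadIdx.y values it contains."""
    warps = {}
    for ty in range(extent_y):
        for tx in range(extent_x):
            linear_id = tx + ty * extent_x
            warps.setdefault(linear_id // warp_size, set()).add(ty)
    return warps


def thread_y_inside_warp(extent_x, extent_y, warp_size=WARP_SIZE):
    """True if some warp holds more than one threadIdx.y value."""
    groups = warps_by_thread_y(extent_x, extent_y, warp_size)
    return any(len(ys) > 1 for ys in groups.values())


print(thread_y_inside_warp(32, 32))  # -> False: threadIdx.x fills a whole warp
print(thread_y_inside_warp(16, 32))  # -> True: two threadIdx.y rows per warp
```

Extents that do not divide the warp size evenly also mix rows. For example, with an extent of 48, warp 1 holds the tail of row 0 and the head of row 1.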
   
   Warp has been special-cased in several places, e.g. in bound.cc and here, even before this PR. I pushed back on adding more special-case code, but I am OK with accepting the current change. Please try to add tests.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

