junrushao commented on PR #16361:
URL: https://github.com/apache/tvm/pull/16361#issuecomment-1885967474

   > The scheduling primitives shouldn't be able to impact the output values, only the way in which the outputs are computed. Since there are definitely cases in which producing a nested thread binding changes the result, I was hoping there would be an alternative way to express cooperative fetching without requiring the nested thread binding.
   
   @Lunderberg This is a good point! When designing TensorIR, @spectrometerHBH put a lot of thought into this, specifically for cooperative fetching. He has a proof that nested thread bindings don't change the semantics of a TensorIR program during the scheduling stage, as long as the program satisfies the "compact dataflow" condition. Therefore, I wouldn't worry too much about this particular case.
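As a rough intuition only, here is a toy Python model of cooperative fetching. It is not the compact-dataflow proof and it is not TVM code. Several threads split a tile copy so that each element is written by exactly one thread. The name `cooperative_fetch`, its strided assignment of elements to threads, and the `reverse` flag are assumptions made for illustration.

```python
from typing import List


def cooperative_fetch(src: List[int], num_threads: int, reverse: bool = False) -> List[int]:
    """Toy model (hypothetical helper) of cooperative fetching into a 'shared' buffer.

    Thread `tx` copies elements tx, tx + num_threads, tx + 2 * num_threads, ...
    so every index is owned by exactly one thread.
    """
    shared: List[int] = [0] * len(src)
    write_count = [0] * len(src)
    # Simulate the threads sequentially; `reverse` changes the order in which
    # they run, standing in for the arbitrary scheduling of real GPU threads.
    order = range(num_threads - 1, -1, -1) if reverse else range(num_threads)
    for tx in order:
        for idx in range(tx, len(src), num_threads):
            shared[idx] = src[idx]
            write_count[idx] += 1
    # Disjoint writes: no element is skipped or written twice.
    assert all(c == 1 for c in write_count)
    return shared


# The result matches a plain copy for any thread count and either run order.
tile = list(range(37))  # deliberately not a multiple of the thread counts
for n in (1, 4, 32, 64):
    assert cooperative_fetch(tile, n) == tile
    assert cooperative_fetch(tile, n, reverse=True) == tile
```

Because the writes are disjoint, neither the thread count nor the order in which the threads run changes the fetched tile. This only illustrates the intuition. It says nothing about the exact conditions used in the proof.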
   
   In fact, in the simplified, non-nested case below, the TIR describes two separate kernels. However, creating a TIR schedule for it errors out because `blockIdx.x` is bound with different extents in the two branches. This is actually a valid program, because the split-host-device pass (`tir.transform.SplitHostDevice`) later splits it into two kernels.
   
   ```python
   from tvm.script import tir as T
   
   @T.prim_func
   def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32):
       T.func_attr({"tir.noalias": T.bool(True)})
       A = T.match_buffer(var_A, (1, seq_len * 8), "int32")
       B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8)
       with T.block("exclusive_scan"):
           T.reads()
           T.writes()
           s8: T.int32 = seq_len * 8
           if s8 == 0:
               blockIdx_x = T.launch_thread("blockIdx.x", 1)
           else:
               with T.launch_thread("threadIdx.x", 1024) as threadIdx_x:
                   blockIdx_x = T.launch_thread("blockIdx.x", T.ceildiv(s8, 1024))
                   i: T.int32 = blockIdx_x * 1024 + threadIdx_x
                   if i < s8:
                       B[i // s8, i % s8] = A[i // s8, i % s8]
   ```
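To make the two-kernel reading concrete, here is a plain-Python sketch of what the two branches of `two_kernels` do at run time. It is not TVM API and not the output of the split-host-device pass. `grid_extent`, `run_copy_kernel`, and `THREADS_PER_BLOCK` are hypothetical names, and the GPU threads are simulated sequentially.

```python
from typing import List

THREADS_PER_BLOCK = 1024  # threadIdx.x extent in the second branch


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def grid_extent(seq_len: int) -> int:
    """blockIdx.x extent of the kernel that `two_kernels` launches."""
    s8 = seq_len * 8
    if s8 == 0:
        return 1  # first branch: blockIdx.x bound with extent 1, empty body
    return ceildiv(s8, THREADS_PER_BLOCK)  # second branch


def run_copy_kernel(A: List[int], seq_len: int) -> List[int]:
    """Sequential simulation of the guarded copy, with B flattened to 1-D.

    With shape (1, s8) and i < s8, B[i // s8, i % s8] is just B[0, i].
    """
    s8 = seq_len * 8
    B = [0] * s8
    if s8 == 0:
        return B  # the first kernel does no work
    for bx in range(grid_extent(seq_len)):
        for tx in range(THREADS_PER_BLOCK):
            i = bx * THREADS_PER_BLOCK + tx
            if i < s8:  # guard against the partial last block
                B[i] = A[i]
    return B


assert grid_extent(0) == 1
assert grid_extent(128) == 1  # s8 = 1024 fills exactly one block
assert grid_extent(129) == 2  # s8 = 1032 spills into a second block
data = list(range(129 * 8))
assert run_copy_kernel(data, 129) == data
```

Each branch corresponds to its own launch. The `blockIdx.x` extent of 1 in the first branch and the extent `ceildiv(s8, 1024)` in the second therefore never apply to the same kernel.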
   
   I added this as a test case.

