roastduck opened a new issue #5374: [TE] Warp memory in InferBound
URL: https://github.com/apache/incubator-tvm/issues/5374
 
 
   I'm working with a buffer that is bound to warp scope. In `src/te/schedule/message_passing.cc:208`:
   
   ```c++
   PrimExpr outer = state.at(s->outer);
   PrimExpr inner = state.at(s->inner);
   PrimExpr factor = dom_map.at(s->inner)->extent;
   PrimExpr parent_min = dom_map.at(s->parent)->min;
   state[s->parent] = inner + outer * factor;
   // add min if they exist
   if (!is_zero(parent_min)) {
       state[s->parent] = state[s->parent] + parent_min;
   }
   ```
   
   I found that `threadIdx.y` is present in both `state[s->parent]` and `parent_min` in my application, so the result becomes `threadIdx.y + threadIdx.y + ...`, which eventually produces a wrong boundary check.
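To make the double counting concrete, here is a minimal model of the split step quoted above, with every `PrimExpr` replaced by a plain integer evaluated for one fixed thread. The name `PassUpSplit` and all the numbers (threadIdx.y = 3, split factor 4, inner = 1) are hypothetical and not taken from my application; the sketch only shows what happens when the `threadIdx.y` offset is carried by both `outer` and `parent_min`.

```c++
// Hypothetical integer model of the split case in message_passing.cc:
// parent = inner + outer * factor, plus parent_min when it is non-zero.
constexpr int PassUpSplit(int inner, int outer, int factor, int parent_min) {
  int parent = inner + outer * factor;
  if (parent_min != 0) {
    parent += parent_min;  // mirrors "add min if they exist"
  }
  return parent;
}

// Assumed values, for illustration only.
constexpr int kThreadIdxY = 3;  // value of threadIdx.y for this thread
constexpr int kFactor = 4;      // extent of s->inner
constexpr int kInner = 1;       // value of s->inner

// threadIdx.y is carried only by `outer`; parent_min is zero.
constexpr int kExpected = PassUpSplit(kInner, kThreadIdxY, kFactor, 0);
// threadIdx.y is carried by `outer` AND by parent_min: its offset is added twice.
constexpr int kDoubled = PassUpSplit(kInner, kThreadIdxY, kFactor, kThreadIdxY * kFactor);
```

In this model `kDoubled - kExpected` equals `kThreadIdxY * kFactor`, which is the extra `threadIdx.y` term that shifts the index used for the bound check.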
   
   I traced where `state[s->parent]` comes from. In `src/te/operation/op_util.cc:167`, there is a piece of code that treats thread indices differently depending on the storage scope:
   
   ```c++
   runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
   if (stage->scope == "" || stage->scope == "warp" ||
       static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
       value_map[iv] = var;
   } else {
       value_map[iv] = dom->min;
   }
   ```
   
   I think the intent of the code above is:
   
   - Both `threadIdx` and `blockIdx` should be indices of a global memory buffer.
   - `threadIdx` should be an index of a shared memory buffer, but `blockIdx` should not.
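The rule in the two bullets can be written as a small predicate. This is a sketch only: the enum values assume TVM's storage ranks (global = 0, shared = 1, warp = 2, local = 3) and thread ranks (blockIdx = 0, threadIdx = 1) as defined in `runtime/thread_storage_scope.h`, `KeepsThreadVar` is a hypothetical name for the branch in `op_util.cc`, and the empty-scope case is omitted.

```c++
// Hypothetical mirrors of TVM's ranks (values assumed).
enum class StorageRank : int { kGlobal = 0, kShared = 1, kWarp = 2, kLocal = 3 };
enum class ThreadRank : int { kBlockIdx = 0, kThreadIdx = 1 };

// Sketch of the branch in op_util.cc: true means value_map[iv] = var
// (the thread index addresses the buffer), false means value_map[iv] = dom->min.
constexpr bool KeepsThreadVar(StorageRank scope, ThreadRank thread) {
  if (scope == StorageRank::kWarp) {
    return true;  // current special case: every thread index is kept for warp
  }
  return static_cast<int>(scope) <= static_cast<int>(thread);
}
```

Under these assumed ranks, global memory keeps both indices, shared memory keeps only `threadIdx`, local memory keeps neither, and warp memory keeps all of them because of the special case.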
   
   I think there is a defect in how warp memory is handled. `threadIdx.x` (assuming its extent equals the warp size) should indeed be an index of a warp buffer, but `threadIdx.y` should not. Currently, both `threadIdx.x` and `threadIdx.y` appear to be treated as indices.
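A sketch of the behaviour I would expect for warp scope, contrasted with the current one. The `Tag` enum and both function names are hypothetical. The proposed rule assumes that the extent of `threadIdx.x` equals the warp size, and it also assumes that `blockIdx` is dropped for warp buffers just as it is for shared ones.

```c++
// Hypothetical tags for the six CUDA block/thread indices.
enum class Tag { kBlockIdxX, kBlockIdxY, kBlockIdxZ, kThreadIdxX, kThreadIdxY, kThreadIdxZ };

// Current behaviour for stage->scope == "warp": every bound index is kept.
constexpr bool KeepForWarpCurrent(Tag) { return true; }

// Expected behaviour, assuming extent(threadIdx.x) == warp size: only threadIdx.x
// addresses the warp buffer; the other indices select which warp (or block) owns it
// and would be replaced by their domain min.
constexpr bool KeepForWarpProposed(Tag tag) { return tag == Tag::kThreadIdxX; }
```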
   
   I have not figured out the whole picture yet, nor have I constructed a simple enough counterexample. I suspect the code above is not the only place that handles warp memory in bound inference. Where is `parent_min` determined? And should we also consider the case where the extent of `threadIdx.x` is smaller than the warp size?
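For the last question, it may help to recall how CUDA groups threads into warps: threads are numbered with x varying fastest (linear id = x + y * blockDim.x for a 2-D block), and each run of warp-size consecutive ids forms one warp. The sketch below uses hypothetical names (`LocateThread`, `LaneInfo`) and ignores `threadIdx.z`. It suggests that when the extent of `threadIdx.x` is smaller than the warp size, the lane also depends on `threadIdx.y`, so `threadIdx.y` could not simply be replaced by its minimum in that case either.

```c++
// Hypothetical helper following CUDA's thread-to-warp mapping:
// linear id = x + y * blockDim.x, warp = id / warp_size, lane = id % warp_size.
struct LaneInfo {
  int warp;  // index of the warp within the block
  int lane;  // position of the thread within its warp
};

constexpr LaneInfo LocateThread(int tx, int ty, int dim_x, int warp_size) {
  const int linear = tx + ty * dim_x;  // threadIdx.z ignored for brevity
  return LaneInfo{linear / warp_size, linear % warp_size};
}

// extent(threadIdx.x) == warp size: the lane is exactly threadIdx.x.
constexpr LaneInfo kFull = LocateThread(5, 2, 32, 32);
// extent(threadIdx.x) == 16 < warp size: two threadIdx.y rows share one warp.
constexpr LaneInfo kHalf = LocateThread(5, 3, 16, 32);
```

In the second case the lane is 21 = 5 + (3 % 2) * 16, i.e. the parity of `threadIdx.y` selects the upper half of the warp.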
