roastduck opened a new issue #5559:
URL: https://github.com/apache/incubator-tvm/issues/5559


   `HoistIfThenElse` is a pass currently not enabled in TVM. I tried to enable 
it in #5553, but there are too many bugs in this pass. Let's fix them first.
   
   **BUG 1:** `HoistIfThenElse` transforms
   
   ```
   for (n.inner, 0, 2) {
     for (o.inner, 0, 2) {
       if ((((threadIdx.y*2) + n.inner) < 2)) {
         if ((((threadIdx.z*2) + o.inner) < 4)) {
           if ((threadIdx.y < 1)) {
             if ((threadIdx.z < 2)) {
               tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
             }
           }
         }
       }
     }
   }
   ```
   
   into the following, where the outermost condition reads `n.inner` outside the loop that defines it:
   
   ```
   if ((((threadIdx.y*2) + n.inner) < 2)) {
     if ((threadIdx.y < 1)) {
       if ((threadIdx.z < 2)) {
         for (n.inner, 0, 2) {
           for (o.inner, 0, 2) {
             if ((((threadIdx.z*2) + o.inner) < 4)) {
               tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
             }
           }
         }
       }
     }
   }
   ```
   
   Possible cause:
   
   
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L295
   
   It only checks whether `if_stmt` has a preferred position, but that position is not guaranteed to be the current position. Changing the check to
   
   ```c++
   if (if_position_map.count(if_stmt.get()) &&
       if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {
   ```
   
   may solve the problem.
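
   To make the intent of that check concrete, here is a minimal, self-contained sketch. `Var`, `ForNode`, `IfNode`, `IfPositionMap` and `ShouldHoistAtLoop` are hypothetical stand-ins I made up for illustration, not the real TVM types or the actual code of the pass. It only shows the idea: hoist at the current loop only if the recorded position of the `if` is exactly that loop.

   ```c++
   #include <cassert>
   #include <unordered_map>

   // Hypothetical stand-ins for the TVM node types (assumption, not the real API).
   struct Var {};
   struct ForNode { const Var* loop_var; };
   struct IfNode {};

   // Maps each `if` to the loop recorded as its preferred position.
   using IfPositionMap = std::unordered_map<const IfNode*, const ForNode*>;

   // True only when `if_stmt` has a recorded position AND that position is
   // the loop whose variable is `top_for_var` (the loop being processed now).
   bool ShouldHoistAtLoop(const IfPositionMap& if_position_map,
                          const IfNode* if_stmt, const Var* top_for_var) {
     assert(if_stmt != nullptr && top_for_var != nullptr);
     auto it = if_position_map.find(if_stmt);
     if (it == if_position_map.end()) return false;  // no preferred position
     return it->second->loop_var == top_for_var;     // position must match
   }
   ```

   With a map that records the `o.inner` loop for a condition, the function returns true for `o.inner` and false for `n.inner`, which is the distinction the original line misses.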
   
   **BUG 2:** `src/tir/transforms/split_host_device.cc` expects the IR to be in SSA form, where each variable is defined only once. Since we copy loops into both the "then" and the "else" branches, we have to rename the loop variables in the "else" branches so that they differ from those in the "then" branches. I have already written some code for this; see #5553.
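
   For readers who have not looked at #5553, the renaming can be sketched roughly as below. This is not the code from that PR. The mini-IR (`Stmt`, `MakeFor`, `MakeLeaf`) and the function `CopyWithFreshLoopVars` are hypothetical and exist only to illustrate the idea: when a loop nest is copied, every loop variable defined inside the copy gets a fresh name, and uses of it are rewritten.

   ```c++
   #include <cassert>
   #include <memory>
   #include <string>
   #include <unordered_map>
   #include <utility>
   #include <vector>

   // Hypothetical mini-IR: a statement is either a loop or a leaf that
   // lists the variable names it reads.
   struct Stmt {
     bool is_for = false;
     std::string loop_var;            // set when is_for
     std::shared_ptr<Stmt> body;      // set when is_for
     std::vector<std::string> uses;   // set for a leaf
   };
   using StmtPtr = std::shared_ptr<Stmt>;

   StmtPtr MakeFor(std::string var, StmtPtr body) {
     auto s = std::make_shared<Stmt>();
     s->is_for = true;
     s->loop_var = std::move(var);
     s->body = std::move(body);
     return s;
   }

   StmtPtr MakeLeaf(std::vector<std::string> uses) {
     auto s = std::make_shared<Stmt>();
     s->uses = std::move(uses);
     return s;
   }

   // Deep-copies `s`. Every loop variable defined inside the copy is renamed
   // by appending `suffix`; variables defined outside are left untouched.
   StmtPtr CopyWithFreshLoopVars(
       const StmtPtr& s, const std::string& suffix,
       std::unordered_map<std::string, std::string> rename = {}) {
     assert(s != nullptr);
     auto out = std::make_shared<Stmt>();
     if (s->is_for) {
       out->is_for = true;
       out->loop_var = s->loop_var + suffix;
       rename[s->loop_var] = out->loop_var;
       out->body = CopyWithFreshLoopVars(s->body, suffix, rename);
     } else {
       for (const auto& v : s->uses) {
         auto it = rename.find(v);
         out->uses.push_back(it == rename.end() ? v : it->second);
       }
     }
     return out;
   }
   ```

   Copying `for (k) { uses i, k }` with the suffix `.else` gives `for (k.else) { uses i, k.else }`, so `k` is defined only once across the two branches.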
   
   **BUG 3:** `IfThenElse` nodes containing thread indices should not be hoisted over the definition of those indices. This can happen when the `Attr` node for `thread_extent` is scheduled into the body of a `For` node by a `compute_at` command. I have already written some code for this; see #5553.
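
   The rule can be pictured with the following sketch. It is not the code from #5553. `Scope` and `MaxHoistLevels` are hypothetical names, and I assume each enclosing scope can be summarized by the set of variables it defines (a loop defines its loop variable, a `thread_extent` attribute defines a thread index).

   ```c++
   #include <cassert>
   #include <cstddef>
   #include <set>
   #include <string>
   #include <vector>

   // Hypothetical summary of one enclosing scope: the variables it defines.
   struct Scope {
     std::set<std::string> defines;
   };

   // `enclosing` lists the scopes around the condition, innermost first.
   // Returns how many of them the condition may be hoisted over, stopping at
   // the first scope that defines a variable the condition reads.
   std::size_t MaxHoistLevels(const std::set<std::string>& cond_vars,
                              const std::vector<Scope>& enclosing) {
     std::size_t levels = 0;
     for (const Scope& scope : enclosing) {
       for (const auto& v : cond_vars) {
         if (scope.defines.count(v)) return levels;  // must stay inside
       }
       ++levels;
     }
     return levels;
   }
   ```

   For a condition on `threadIdx.y` nested as `for (n.inner) { attr thread_extent threadIdx.y { for (o.inner) { if ... } } }`, the result is 1: the `if` may move out of the `o.inner` loop but not above the attribute that defines `threadIdx.y`.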
   
   **BUG 4:** 
   
   
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L371
   
   Look at this line. `if_stmt` may already have been updated by the time this line runs. Consider the example below.
   
   ```
   for (i, 0, 10) {
     for (j, 0, 10) {
       for (k, 0, 10) {
         if ((i >= 3)) {
           if ((j >= 3)) {
             data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
           }
         }
       }
     }
   }
   ```
   
   After hoisting `j >= 3`, it becomes
   
   ```
   for (i, 0, 10) {
     for (j, 0, 10) {
       if ((j >= 3)) {
         for (k, 0, 10) {
           if ((i >= 3)) {
             data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
           }
         }
       }
     }
   }
   ```
   
   Now, when we are hoisting `i >= 3`, we need to compare and remove
   
   ```
   if ((i >= 3)) {
     if ((j >= 3)) {
       data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
     }
   }
   ```
   
   But `j >= 3` is already gone from that position, so `RemoveIf` fails. We have to track updates to `IfThenElse` nodes, just as we already do for `For` nodes.
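
   One way to do that tracking is sketched below, as an assumption on my side rather than a description of how the pass handles `For`. `IfNode` and `UpdateTracker` are hypothetical. The idea is to record each replacement of a node and resolve any stale pointer to its latest version before comparing or removing it.

   ```c++
   #include <cassert>
   #include <string>
   #include <unordered_map>

   // Hypothetical stand-in for an IfThenElse node.
   struct IfNode {
     std::string text;
   };

   // Remembers which node replaced which, so stale pointers can be resolved.
   class UpdateTracker {
    public:
     // Records that `old_node` was rewritten into `new_node`.
     // Assumes updates never form a cycle.
     void Record(const IfNode* old_node, const IfNode* new_node) {
       assert(old_node != new_node);
       latest_[old_node] = new_node;
     }

     // Follows the chain of recorded updates to the most recent version.
     const IfNode* Resolve(const IfNode* node) const {
       auto it = latest_.find(node);
       while (it != latest_.end()) {
         node = it->second;
         it = latest_.find(node);
       }
       return node;
     }

    private:
     std::unordered_map<const IfNode*, const IfNode*> latest_;
   };
   ```

   In the example above, hoisting `j >= 3` would record that the old `i >= 3` node (with the nested `j >= 3`) was replaced by the new one without it, and the later removal step would call `Resolve` first instead of comparing against the outdated node.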
   
   **BUG 5:** This one is in the tests.
   
   
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/tests/python/unittest/test_tir_pass_hoist_if.py#L175
   
   Why do we expect a `('For', 'j')` inside itself? As a potential problem, maybe we should change the variable names so that there are not two `i`s and two `j`s.
   
   These are all the bugs I found.
   
   Besides, I suggest changing every `for (size_t i = 0; i < xxx.size(); i++)` into `for (size_t i = 0, n = xxx.size(); i < n; i++)`, since the C++ compiler cannot always prove that `xxx.size()` is loop-invariant.

