roastduck opened a new issue #5559:
URL: https://github.com/apache/incubator-tvm/issues/5559
`HoistIfThenElse` is a pass currently not enabled in TVM. I tried to enable
it in #5553, but there are too many bugs in this pass. Let's fix them first.
**BUG 1:** `HoistIfThenElse` transforms
```
for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16,
              ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv,
              (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) +
              (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}
```
into
```
if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16,
              ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv,
              (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) +
              (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}
```
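Note that the condition `((threadIdx.y*2) + n.inner) < 2` reads `n.inner`, yet
it has been hoisted out of the `for (n.inner, 0, 2)` loop, so `n.inner` is used
outside its scope.
Possible cause: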
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L295
It only checks whether `if_stmt` has a preferred position, not whether that
position is the current one. Changing it to
```c++
if (if_position_map.count(if_stmt.get()) &&
if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() ==
top_for_var) {
```
may solve the problem.
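The invariant behind this fix can be sketched with a small, self-contained C++
model. It is not TVM code: `VarSet` and `MinLoopDepth` are hypothetical names,
and conditions and loops are reduced to plain variable names. It computes how
many enclosing loops must still wrap an `if` after hoisting.
```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model, not TVM's API: a condition is reduced to the set
// of variable names it reads, and a loop nest to its loop variables,
// ordered from outermost to innermost.
using VarSet = std::set<std::string>;

// Returns how many of the enclosing loops must still surround the `if`
// after hoisting. The `if` may move above every loop nested inside the
// innermost loop whose variable it reads, but no further.
std::size_t MinLoopDepth(const std::vector<std::string>& loops,
                         const VarSet& cond_vars) {
  std::size_t depth = 0;
  for (std::size_t i = 0, n = loops.size(); i < n; ++i) {
    if (cond_vars.count(loops[i]) != 0) {
      depth = i + 1;  // the `if` must stay inside loops[0..i]
    }
  }
  return depth;
}
```
Under this model, `((threadIdx.y*2) + n.inner) < 2` needs depth 1 in the
`n.inner`, `o.inner` nest, whereas the BUG 1 output placed it at depth 0.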
**BUG 2:** `src/tir/transforms/split_host_device.cc` requires the IR to be in
SSA form, where each variable is defined only once. Since we copy loops into
both the "then" and the "else" branches, we have to rename the loop variables in
the "else" branches so that they differ from those in the "then" branches. I
have already written some code for this, see #5553.
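A minimal sketch of the renaming, assuming a toy representation in which a loop
nest is just the list of its variable names. `LoopNest` and `CopyWithFreshVars`
are hypothetical names; this is not the code in #5553.
```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy model: a loop nest is the list of its loop-variable
// names, outermost first.
struct LoopNest {
  std::vector<std::string> vars;
};

// Copies a loop nest for the "else" branch, giving every loop variable a
// fresh name so that each variable is still defined exactly once (SSA).
// `subst` records old -> new names so that uses inside the copied loop
// body can be rewritten with the same mapping.
LoopNest CopyWithFreshVars(const LoopNest& nest, int& counter,
                           std::map<std::string, std::string>& subst) {
  LoopNest copy;
  for (const std::string& v : nest.vars) {
    std::string fresh = v + "_" + std::to_string(counter++);
    subst[v] = fresh;
    copy.vars.push_back(fresh);
  }
  return copy;
}
```
Since TIR variables are IR objects rather than strings, a real implementation
would more likely create new variable objects and substitute them in the copied
body; the name suffix here only keeps the sketch self-contained.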
**BUG 3:** `IfThenElse` nodes whose conditions involve thread indices should not
be hoisted above the definitions of those indices. This can happen when the
`Attr` node for `thread_extent` is scheduled into the body of a `For` node with
a `compute_at` command. I have already written some code for this, see #5553.
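To illustrate, here is a toy C++ check, assuming every scope an `if` might be
moved above defines exactly one variable. `Scope`, `ScopeKind` and
`HoistIsIllegal` are hypothetical names. Passing `check_thread_extent = false`
is meant to mimic the reported behavior, under the assumption that only loop
variables are considered.
```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy model, not TVM's API. A `For` defines its loop variable;
// an `Attr` node with key `thread_extent` defines a thread index.
enum class ScopeKind { kFor, kThreadExtent };

struct Scope {
  ScopeKind kind;
  std::string defined_var;
};

// Returns true if moving an `if` above every scope in `crossed` would take
// a variable read by its condition out of scope. With
// `check_thread_extent == false`, thread_extent scopes are skipped, which
// (by assumption) models the buggy behavior.
bool HoistIsIllegal(const std::vector<Scope>& crossed,
                    const std::set<std::string>& cond_vars,
                    bool check_thread_extent) {
  for (const Scope& s : crossed) {
    if (s.kind == ScopeKind::kThreadExtent && !check_thread_extent) continue;
    if (cond_vars.count(s.defined_var) != 0) return true;
  }
  return false;
}
```
For example, an `if (threadIdx.z < 1)` nested inside a `thread_extent` attribute
that `compute_at` placed within a `for (ax0, ...)` loop must not be hoisted
above `ax0`, because that would cross the definition of `threadIdx.z`.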
**BUG 4:**
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L371
Look at this line: `if_stmt` may already have been updated by the time this
line runs. Consider the example below.
```
for (i, 0, 10) {
  for (j, 0, 10) {
    for (k, 0, 10) {
      if ((i >= 3)) {
        if ((j >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}
```
After hoisting `j >= 3`, it becomes
```
for (i, 0, 10) {
  for (j, 0, 10) {
    if ((j >= 3)) {
      for (k, 0, 10) {
        if ((i >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}
```
Now, when hoisting `i >= 3`, we need to match and remove
```
if ((i >= 3)) {
  if ((j >= 3)) {
    data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
  }
}
```
But `j >= 3` is already gone, so `RemoveIf` fails. We have to track updates to
`IfThenElse` nodes just as we do for `For` nodes.
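One way to track these updates, sketched under the assumption that each
`IfThenElse` node can be identified by an integer id: record every replacement
and follow the chain before searching for the statement to remove.
`ReplacementTracker` is a hypothetical name, not a structure in the pass.
```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical sketch: integer ids stand in for IfThenElse nodes. Whenever
// a rewrite replaces a node, record old -> new so that later lookups find
// the node that is currently in the IR.
class ReplacementTracker {
 public:
  void Record(int old_id, int new_id) { next_[old_id] = new_id; }

  // Follows the chain of replacements to the latest version of a node.
  // Assumes replacements never form a cycle.
  int Resolve(int id) const {
    auto it = next_.find(id);
    while (it != next_.end()) {
      id = it->second;
      it = next_.find(id);
    }
    return id;
  }

 private:
  std::unordered_map<int, int> next_;
};
```
In the example above, the outer `if ((i >= 3))` statement would be recorded as
replaced once `j >= 3` is hoisted out of it, so resolving its id yields the
statement that actually remains in the IR.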
**BUG 5:** This one is in the tests.
https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/tests/python/unittest/test_tir_pass_hoist_if.py#L175
Why do we expect a `('For', 'j')` inside itself? A related potential problem:
maybe we should rename the variables so that the test no longer contains two
`i`s and two `j`s.
These are all the bugs I found.
Besides, I suggest changing every `for (size_t i = 0; i < xxx.size(); i++)`
into `for (size_t i = 0, n = xxx.size(); i < n; i++)`, since the C++ compiler
often cannot prove that `xxx.size()` is loop-invariant.
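The two loop forms side by side, as a small self-contained example (the
function names are illustrative only):
```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Form 1: the condition calls v.size() on every iteration. The optimizer
// may hoist it only if it can prove `v` is unchanged by the body; if the
// body called a function it cannot see into, it would have to re-read it.
long SumReloadSize(const std::vector<int>& v) {
  long sum = 0;
  for (std::size_t i = 0; i < v.size(); i++) sum += v[i];
  return sum;
}

// Form 2: the size is read once into `n`, a local the body cannot modify,
// so no such proof is needed.
long SumCachedSize(const std::vector<int>& v) {
  long sum = 0;
  for (std::size_t i = 0, n = v.size(); i < n; i++) sum += v[i];
  return sum;
}
```
The cached form is only correct when the loop body never resizes the container,
so each loop should be checked before it is converted.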