lazycal opened a new issue #9598:
URL: https://github.com/apache/tvm/issues/9598
```python
import tvm
from tvm import te
import numpy as np
import tvm.testing

F = 100    # split factor
N = F + 1  # extent, deliberately not a multiple of F

A = te.placeholder((N, N), name="A")
k = te.reduce_axis((0, N), name="k")
B = te.compute((N,), lambda i: te.sum(A[i, k], k), name="B")
C = te.compute((N,), lambda i: B[i], name="C")

s = te.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=F)
s[B].compute_at(s[C], xi)  # compute the reduction under C's inner split axis

foo = tvm.build(s, [A, B, C], "llvm")
print(tvm.lower(s, [A, B, C], simple_mode=True))

anp = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), tvm.cpu())
bnp = tvm.nd.array(np.random.uniform(size=(N,)).astype(A.dtype), tvm.cpu())
cnp = tvm.nd.array(np.random.uniform(size=(N,)).astype(A.dtype), tvm.cpu())
foo(anp, bnp, cnp)
tvm.testing.assert_allclose(bnp.asnumpy(), cnp.asnumpy())
```
This triggers a segmentation fault. The produced IR is
```c
@main = primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [101], []),
             A: Buffer(A_2: Pointer(float32), float32, [101, 101], []),
             B: Buffer(B_2: Pointer(float32), float32, [101], [])}
  buffer_map = {A_1: A, B_1: B, C_1: C} {
  for (i.outer: int32, 0, 2) {
    for (i.inner: int32, 0, 100) {
      B_2[((i.outer*100) + i.inner)] = 0f32
      if @tir.likely((((i.outer*100) + i.inner) < 101), dtype=bool) {
        for (k: int32, 0, 101) {
          B_2[((i.outer*100) + i.inner)] = ((float32*)B_2[((i.outer*100) + i.inner)] + (float32*)A_2[(((i.outer*10100) + (i.inner*101)) + k)])
        }
      }
      if @tir.likely((((i.outer*100) + i.inner) < 101), dtype=bool) {
        C_2[((i.outer*100) + i.inner)] = (float32*)B_2[((i.outer*100) + i.inner)]
      }
    }
  }
}
```
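where `B_2[((i.outer*100) + i.inner)] = 0f32` is not guarded by the
`likely` predicate, unlike the reduction body. With `i.outer == 1` and
`i.inner >= 1` the index is 101 or larger, so the init writes past the end of
`B`, which has only 101 elements.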
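
To make the out-of-bounds access concrete, here is a small plain-Python model of the loop nest in the IR above. It does not use TVM at all; it only enumerates the indices that the init statement writes, with and without the `< 101` guard. The helper name `init_write_indices` is made up for this illustration.

```python
# Plain-Python model of the two outer loops in the lowered IR (no TVM needed).
F = 100                    # split factor, extent of i.inner
N = F + 1                  # extent of B and C (101)
OUTER = (N + F - 1) // F   # ceil(N / F) = 2, extent of i.outer


def init_write_indices(guarded):
    """Return the indices written by the `B[...] = 0` init statement.

    `guarded` models whether the init sits inside the `idx < N` predicate.
    (Hypothetical helper, only for illustration.)
    """
    written = []
    for i_outer in range(OUTER):
        for i_inner in range(F):
            idx = i_outer * F + i_inner
            if guarded and idx >= N:
                continue  # predicate false: the write is skipped
            written.append(idx)
    return written


# Indices outside the valid range [0, N) in each variant.
unguarded_oob = [i for i in init_write_indices(guarded=False) if i >= N]
guarded_oob = [i for i in init_write_indices(guarded=True) if i >= N]

print(len(unguarded_oob), unguarded_oob[0], unguarded_oob[-1])  # 99 101 199
print(len(guarded_oob))                                         # 0
```

In this model the unguarded init touches 99 indices beyond the 101-element buffer, which is consistent with the observed segmentation fault.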
## Investigation
The problem goes away if the bound check is not skipped, i.e. if
`!stage->rolling_buffer` is replaced with `false` in
https://github.com/apache/tvm/blob/adf560ebed8465c22bf58f406d0a8d20663cdd1d/src/te/operation/compute_op.cc#L488.
However, I'm not sure this is the right fix, because I am having trouble
understanding the bound-checking logic. What confuses me is why the
reduction body does not skip the bound checks (see
https://github.com/apache/tvm/blob/adf560ebed8465c22bf58f406d0a8d20663cdd1d/src/te/operation/compute_op.cc#L447)
while the init does.

I see that there are two kinds of bound checks (L550-L560 and L561-L577) in
the `MakeBoundCheck` function
https://github.com/apache/tvm/blob/adf560ebed8465c22bf58f406d0a8d20663cdd1d/src/te/schedule/message_passing.cc#L526-L580
and that the `skip_ivar_domain` flag only controls the second one. But
the first check does not seem comprehensive: in the code above, because of the
`compute_at`, `B`'s axis is implicitly bound to a split axis of `C`, but the
first check cannot see that split relation. As a result, `PassUpBoundCheck`
doesn't mark the axis as needing a check. So I'm also curious whether this is
expected or not.
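
For reference, the semantics one would expect once the init is guarded by the same predicate as the reduction body can be modeled in plain Python. This is only a sketch of the intended behavior, not TVM output, and it says nothing about how the fix should be implemented in `MakeBoundCheck`. The function name `emulate_guarded` and the exact placement of the guard are assumptions.

```python
import math
import random

F = 100    # split factor
N = F + 1  # extent of A's rows, B and C


def emulate_guarded(a):
    """Emulate the lowered loop nest with the init inside the bound check.

    Hypothetical model: lists stand in for the buffers, and `None` marks
    elements that were never written.
    """
    b = [None] * N
    c = [None] * N
    for i_outer in range((N + F - 1) // F):
        for i_inner in range(F):
            i = i_outer * F + i_inner
            if i < N:  # assumed: the predicate also covers the init
                b[i] = 0.0
                for k in range(N):
                    b[i] += a[i][k]
                c[i] = b[i]
    return b, c


random.seed(0)
a = [[random.random() for _ in range(N)] for _ in range(N)]
b, c = emulate_guarded(a)

# Every element is written exactly within bounds, B holds the row sums,
# and C is a copy of B.
assert all(x is not None for x in b)
assert all(math.isclose(x, sum(row)) for x, row in zip(b, a))
assert b == c
```

With the guard in place, no index outside `[0, N)` is ever touched, and `B` and `C` agree, which is what the `assert_allclose` in the reproduction script checks.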
### Environment
OS: Ubuntu 18.04
TVM Version: ecd8a9ce33991262f4184cb857f1088d9d8e1bb1