yzh119 opened a new issue #6596:
URL: https://github.com/apache/incubator-tvm/issues/6596
Currently, TVM's boundary check prevents some invalid global memory accesses. However, it ignores the case where the arguments of `reduce_axis` themselves require a global memory access, for example a read from an index tensor. This is common when dealing with sparse or ragged tensors.
Below is a simple example (segment sum) that reproduces the problem. It does the following:
- Given a data tensor `x` and an offset tensor `offsets` (which encodes the segment boundaries),
- for each segment index `i`, compute the sum of the elements of `x` inside that segment, `sum(x[offsets[i]:offsets[i+1]])`, and store the result in `out[i]`.
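For reference, the intended semantics can be written as a short plain-Python sketch. It does not involve TVM, and the function name `segment_sum_reference` and the sample data are illustrative assumptions. Note that `offsets` holds `num_segments + 1` entries, and that the length of each reduction depends on values loaded from `offsets`:

```python
from typing import List


def segment_sum_reference(x: List[float], offsets: List[int]) -> List[float]:
    """Plain-Python reference for out[i] = sum(x[offsets[i]:offsets[i+1]])."""
    num_segments = len(offsets) - 1  # offsets has num_segments + 1 entries
    out = []
    for i in range(num_segments):
        # The reduction extent is data-dependent: it comes from `offsets`,
        # which is the global-memory read that ends up inside `reduce_axis`.
        start, end = offsets[i], offsets[i + 1]
        out.append(sum(x[start:end], 0.0))
    return out


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
offsets = [0, 2, 2, 5, 6]  # segments: [0:2], [2:2] (empty), [2:5], [5:6]
print(segment_sum_reference(x, offsets))  # [3.0, 0.0, 12.0, 6.0]
```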
```python
import tvm
import tvm.te as te

num_elements = te.var('num_elements', dtype='int32')
num_segments = te.var('num_segments', dtype='int32')
# Data tensor and segment offsets; offsets has num_segments + 1 entries.
x = te.placeholder((num_elements,), dtype='float32', name='x')
offsets = te.placeholder((num_segments + 1,), dtype='int32', name='offsets')


def segment_sum(i):
    """Compute sum(x[offsets[i]:offsets[i+1]])"""
    # The reduction extent depends on values loaded from `offsets`.
    k = te.reduce_axis((0, offsets[i + 1] - offsets[i]))
    return te.sum(x[k + offsets[i]], axis=k)


out = te.compute(
    (num_segments,),
    segment_sum,
    name='out'
)

s = te.create_schedule(out.op)
segment_axis = out.op.axis[0]
# One thread per segment, 4 threads per block.
segment_outer, segment_inner = s[out.op].split(segment_axis, factor=4)
s[out.op].bind(segment_inner, te.thread_axis('threadIdx.x'))
s[out.op].bind(segment_outer, te.thread_axis('blockIdx.x'))
print(tvm.lower(s, [x, offsets, out]))
```
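To see why a guard is needed at all, the following plain-Python sketch enumerates the segment index `(blockIdx.x*4) + threadIdx.x` handled by each launched thread. It is an emulation rather than TVM code, and `launched_indices` is a hypothetical helper. Because the grid size is rounded up to a multiple of 4, the last block can contain threads whose index is past the last segment:

```python
def launched_indices(num_segments: int, factor: int = 4):
    """Emulate split(factor) + bind: list the segment index of every launched thread."""
    # Number of blocks is floordiv(num_segments + factor - 1, factor),
    # i.e. floordiv(num_segments + 3, 4) for factor = 4.
    num_blocks = (num_segments + factor - 1) // factor
    indices = []
    for block_idx in range(num_blocks):      # blockIdx.x
        for thread_idx in range(factor):     # threadIdx.x
            indices.append(block_idx * factor + thread_idx)
    return indices


idx = launched_indices(5)
padded = [i for i in idx if i >= 5]
print(len(idx), padded)  # 8 [5, 6, 7]
```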
Below is the generated code:
```
primfn(x_1: handle, offsets_1: handle, out_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {out: Buffer(out_2: Pointer(float32), float32, [num_elements: int32], [stride: int32], type="auto"),
             x: Buffer(x_2: Pointer(float32), float32, [num_elements_1: int32], [stride_1: int32], type="auto"),
             offsets: Buffer(offsets_2: Pointer(int32), int32, [(num_elements + 1)], [])}
  buffer_map = {x_1: x, offsets_1: offsets, out_1: out} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((num_elements + 3), 4);
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 4;
  if (blockIdx.x < floordiv(num_elements, 4)) {
    out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = 0f32
    for (rv: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
      if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
        out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
      }
    }
  } else {
    if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
      out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = 0f32
    }
    for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
      if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
        out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv_1 + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
      }
    }
  }
}
```
Note that in the loop header `for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {`, the reads from `offsets_2` are not protected by the bound check. As a result, an invalid memory access error occurs when `((blockIdx.x*4) + threadIdx.x)` is greater than or equal to `num_elements` (the segment count, as printed in the IR above).
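The out-of-bounds pattern can be checked with a small plain-Python emulation. This is not TVM code, and `oob_offset_reads` is a hypothetical helper. It mirrors the loop header above, which evaluates `offsets_2[i + 1] - offsets_2[i]` for every launched thread before any `i < num_elements` check, and reports each `(i, index)` read that falls outside the `num_segments + 1` entries of `offsets`:

```python
def oob_offset_reads(num_segments: int, factor: int = 4):
    """Emulate the unguarded loop extent: for each launched thread index i,
    offsets[i + 1] and offsets[i] are read before the bounds guard runs.
    Returns the (thread index, offsets index) pairs that are out of range."""
    offsets_len = num_segments + 1  # offsets has shape (num_segments + 1,)
    num_blocks = (num_segments + factor - 1) // factor
    bad = []
    for block_idx in range(num_blocks):
        for thread_idx in range(factor):
            i = block_idx * factor + thread_idx
            # Loop extent: offsets_2[i + 1] - offsets_2[i], evaluated unconditionally.
            # In full blocks every i < num_segments, so only the tail block can fail.
            for read in (i + 1, i):
                if read >= offsets_len:
                    bad.append((i, read))
    return bad


print(oob_offset_reads(5))  # [(5, 6), (6, 7), (6, 6), (7, 8), (7, 7)]
```

With 5 segments, 8 threads are launched, and threads 5 to 7 read past the end of `offsets`. When the segment count is a multiple of 4 there is no tail block and no out-of-bounds read.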
If we swap the order of the if-statement and the for-loop, so that the guard wraps the loop, the program should work correctly:
```
if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
  for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
    out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv_1 + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
  }
}
```
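The reordered code can be checked with a plain-Python emulation. This is a sketch rather than TVM output, and `segment_sum_guarded` and the sample data are hypothetical. Because the guard is hoisted above the loop, padded threads never touch `offsets`. If any thread did read past the end of the list, Python's list indexing would raise `IndexError`:

```python
from typing import List


def segment_sum_guarded(x: List[float], offsets: List[int], factor: int = 4) -> List[float]:
    """Emulate the reordered kernel: the bounds guard wraps the reduction loop,
    so `offsets` is only read by threads whose index is a valid segment."""
    num_segments = len(offsets) - 1
    out = [0.0] * num_segments
    num_blocks = (num_segments + factor - 1) // factor
    for block_idx in range(num_blocks):          # blockIdx.x
        for thread_idx in range(factor):         # threadIdx.x
            i = block_idx * factor + thread_idx
            if i < num_segments:                 # guard hoisted above the loop
                for rv in range(offsets[i + 1] - offsets[i]):
                    out[i] += x[rv + offsets[i]]
    return out


# 5 segments -> 2 blocks of 4 threads; threads 5..7 are skipped by the guard.
print(segment_sum_guarded([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0, 2, 2, 5, 6, 6]))
# [3.0, 0.0, 12.0, 6.0, 0.0]
```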
This bug was also mentioned on the [TVM forum](https://discuss.tvm.apache.org/t/tvm-access-beyond-array-boundary/6998).
I think this error is related to https://github.com/apache/incubator-tvm/blob/f13fed55cfe872ba7f40970f6a35f965d186a30a/src/tir/transforms/bound_checker.cc. How could I change it so that it is aware of global memory accesses in `reduce_axis`?
cc @junrushao1994 , @hzfan