yzh119 opened a new issue #6596:
URL: https://github.com/apache/incubator-tvm/issues/6596


   TVM's boundary check currently prevents some invalid global memory 
accesses, but it misses the case where the extent passed to `reduce_axis` 
itself requires a global memory access (for example, a read from an index 
tensor, which is common when working with sparse or ragged tensors).
   
   Below is a simple example (segment sum) that reproduces the problem. It 
does the following:
   - Given a data tensor `x` and an offset tensor `offsets` that marks the 
segment boundaries.
   - For each segment index `i`, compute the sum of the elements of `x` inside 
that segment, `sum(x[offsets[i]:offsets[i+1]])`, and store the result in 
`out[i]`.
   ```
   import tvm
   import tvm.te as te
   
   num_elements = te.var('num_elements', dtype='int32')
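   # NOTE: this variable is also named 'num_elements', so the lowered code below
   # prints the segment count as `num_elements` and the length of `x` as
   # `num_elements_1`.
   num_segments = te.var('num_elements', dtype='int32')
   x = te.placeholder((num_elements,), dtype='float32', name='x')
   offsets = te.placeholder((num_segments + 1,), dtype='int32', name='offsets')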
   
   def segment_sum(i):
       """Compute sum(x[offsets[i]:offsets[i+1]])"""
       k = te.reduce_axis((0, offsets[i + 1] - offsets[i]))
       return te.sum(x[k + offsets[i]], axis=k)
   
   out = te.compute(
       (num_segments,),
       segment_sum,
       name='out'
   )
   
   s = te.create_schedule(out.op)
   segment_axis = out.op.axis[0]
   segment_outer, segment_inner = s[out.op].split(segment_axis, factor=4)
   s[out.op].bind(segment_inner, te.thread_axis('threadIdx.x'))
   s[out.op].bind(segment_outer, te.thread_axis('blockIdx.x'))
   
   print(tvm.lower(s, [x, offsets, out]))
   ```
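
For reference, the intended result can be written in plain Python. This is only a sketch of the semantics described above, not part of the TVM reproduction; the helper name `segment_sum_reference` and the sample data are made up for illustration.

```python
def segment_sum_reference(x, offsets):
    """Return sum(x[offsets[i]:offsets[i+1]]) for every segment i."""
    num_segments = len(offsets) - 1
    # Start each sum at 0.0 so that an empty segment yields a float zero.
    return [sum(x[offsets[i]:offsets[i + 1]], 0.0) for i in range(num_segments)]


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
offsets = [0, 2, 2, 6]  # three segments; the middle one is empty
print(segment_sum_reference(x, offsets))  # [3.0, 0.0, 18.0]
```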
   
   The script prints the following lowered code:
   ```
   primfn(x_1: handle, offsets_1: handle, out_1: handle) -> ()
     attr = {"global_symbol": "main", "tir.noalias": True}
     buffers = {out: Buffer(out_2: Pointer(float32), float32, [num_elements: int32], [stride: int32], type="auto"),
                x: Buffer(x_2: Pointer(float32), float32, [num_elements_1: int32], [stride_1: int32], type="auto"),
                offsets: Buffer(offsets_2: Pointer(int32), int32, [(num_elements + 1)], [])}
     buffer_map = {x_1: x, offsets_1: offsets, out_1: out} {
     attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((num_elements + 3), 4);
     attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 4;
     if (blockIdx.x < floordiv(num_elements, 4)) {
       out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = 0f32
       for (rv: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
         if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
           out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
         }
       }
     } else {
       if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
         out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = 0f32
       }
       for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
         if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
           out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv_1 + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
         }
       }
     }
   }
   ```
   
   Note that the loop bound in 
`for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {` 
reads `offsets_2` outside any guard. This causes an invalid memory access 
whenever `((blockIdx.x*4) + threadIdx.x)` is greater than or equal to 
`num_elements` (which here denotes the number of segments).
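
To make the failing case concrete, the launch grid can be enumerated in plain Python. This is a rough sketch that only mirrors the index arithmetic of the lowered code (split factor 4, `floordiv(n + 3, 4)` blocks); the helper name `out_of_bounds_threads` is hypothetical and nothing here calls TVM.

```python
def out_of_bounds_threads(num_segments, factor=4):
    """List (blockIdx.x, threadIdx.x, index) for reads past the end of offsets."""
    num_blocks = (num_segments + factor - 1) // factor  # extent of blockIdx.x
    offsets_len = num_segments + 1
    bad = []
    for block in range(num_blocks):
        for thread in range(factor):
            i = block * factor + thread
            # The loop extent reads offsets[i + 1] before any guard is checked.
            if i + 1 >= offsets_len:
                bad.append((block, thread, i + 1))
    return bad


print(out_of_bounds_threads(6))  # [(1, 2, 7), (1, 3, 8)]
print(out_of_bounds_threads(8))  # []
```

With 6 segments, `offsets` has 7 entries, so the last two threads of the second block read indices 7 and 8. When the segment count is a multiple of 4 there are no surplus threads and the problem does not show up.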
   
   If we swap the `if` statement and the `for` loop so that the guard comes 
first, the program should work correctly:
   ```
   if (((blockIdx.x*4) + threadIdx.x) < num_elements) {
     for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
       out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv_1 + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
     }
   }
   ```
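
The guard-first ordering can be emulated on the CPU to check that surplus threads never touch `offsets`. This is a minimal sketch under stated assumptions: the function name `run_guarded_kernel` is hypothetical, the grid is simulated with nested Python loops, and the zero initialization (omitted from the snippet above) is assumed to sit under the same guard.

```python
def run_guarded_kernel(x, offsets, factor=4):
    """Emulate the guard-first kernel over a simulated (block, thread) grid."""
    num_segments = len(offsets) - 1
    num_blocks = (num_segments + factor - 1) // factor
    out = [None] * num_segments
    for block in range(num_blocks):
        for thread in range(factor):
            i = block * factor + thread
            if i < num_segments:  # guard first: offsets is read only for valid i
                out[i] = 0.0
                for rv in range(offsets[i + 1] - offsets[i]):
                    out[i] += x[rv + offsets[i]]
    return out


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
offsets = [0, 1, 3, 3, 4, 5, 6]  # 6 segments, so 2 of the 8 threads are surplus
print(run_guarded_kernel(x, offsets))  # [1.0, 5.0, 0.0, 4.0, 5.0, 6.0]
```

A Python list raises `IndexError` on a read past its end, so the run completing is a small sanity check that the two surplus threads skip the `offsets` reads entirely.
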
   The bug was also reported on the [TVM 
forum](https://discuss.tvm.apache.org/t/tvm-access-beyond-array-boundary/6998).
   
   I think this error is related to 
https://github.com/apache/incubator-tvm/blob/f13fed55cfe872ba7f40970f6a35f965d186a30a/src/tir/transforms/bound_checker.cc.
 How could I change it so that it is aware of global memory accesses in 
`reduce_axis`?
   
   cc @junrushao1994 , @hzfan 

