lazycal opened a new issue #8540:
URL: https://github.com/apache/tvm/issues/8540


   TVM seems to mishandle alignment requirements when vectorizing on CUDA. Script to reproduce:
   ```python
   import tvm
   from tvm import relay
   import numpy as np
   from tvm.contrib import graph_executor
   
   
   def gen_one_random_feed(module):
       feed = {}
       for i in range(module.get_num_inputs()):
           data = np.random.rand(*module.get_input(i).shape)
           feed[i] = data
       return feed
   
   
   def run(mod):
       target = "cuda"
       lib = relay.build(mod, target=target)
       dev = tvm.cuda()
       rt_mod = graph_executor.GraphModule(lib["default"](dev))
       input_dict = gen_one_random_feed(rt_mod)
       for k, v in input_dict.items():
           rt_mod.set_input(k, v)
       rt_mod.run()
   
   
   def get_mod():
       x = relay.var('x', shape=(1, 3,), dtype='float16')
       t1 = relay.strided_slice(x, [0, 1], [1, 3], [1, 1])
       func = relay.Function([x], t1)
       return tvm.IRModule.from_expr(func)
   
   
   mod = get_mod()
   run(mod)
   print('Passed!')
   ```
   
   The generated CUDA kernel code:
   ```cuda
   extern "C" __global__ void tvmgen_default_fused_strided_slice_kernel0(half* __restrict__ T_strided_slice, half* __restrict__ placeholder) {
     if (((int)threadIdx.x) < 1) {
       ((uint1*)(T_strided_slice + ((((int)threadIdx.x) * 2))))[0] = ((uint1*)(placeholder + (((((int)threadIdx.x) * 3) + 1))))[0];
     }
   }
   ```
   It reads two `half`s (4 bytes, reinterpreted as one `uint1`) starting at element offset 1 of `placeholder`, i.e. at byte offset 2, which is not 4-byte aligned (assuming `placeholder` itself is at least 4-byte aligned).
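To make the alignment problem concrete, here is a small Python sketch of the byte arithmetic behind the kernel above. It assumes, as the report does, that both `placeholder` and `T_strided_slice` start at 4-byte-aligned addresses; the helper names are hypothetical and only mirror the index expressions in the generated code.

```python
# Byte-offset arithmetic for the generated strided_slice kernel, assuming
# both buffers start at 4-byte-aligned addresses.
HALF_BYTES = 2   # sizeof(half)
UINT1_BYTES = 4  # the kernel moves two halfs at once as a single uint1


def load_offset(tid: int) -> int:
    """Byte offset of the uint1 read from `placeholder` for thread `tid`."""
    return (tid * 3 + 1) * HALF_BYTES


def store_offset(tid: int) -> int:
    """Byte offset of the uint1 write into `T_strided_slice` for thread `tid`."""
    return (tid * 2) * HALF_BYTES


def is_aligned(byte_offset: int, access_bytes: int = UINT1_BYTES) -> bool:
    """A 4-byte access must start at a multiple of 4 bytes."""
    return byte_offset % access_bytes == 0


# Only threadIdx.x == 0 does any work (the kernel guards with threadIdx.x < 1).
print(load_offset(0), is_aligned(load_offset(0)))    # 2 False
print(store_offset(0), is_aligned(store_offset(0)))  # 0 True
```

The write side is fine because its element offset is a multiple of 2, while the read side starts one `half` into a 4-byte word.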
   
   ## Some analysis and possible fix
   After a closer look, I found that this is caused by the vectorization introduced here: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/injective.py#L55. Setting `vector_width=1` avoids the issue. I also managed to craft a TE compute and schedule that trigger the bug, for better illustration:
   
   ```python
   import tvm
   from tvm import te
   import numpy as np
   
   def get_te_sch():
       N = 3
       C_N = N - 1
       A = te.placeholder((N,), name="A", dtype='float16')
       C = te.compute((C_N,), lambda i: A[i+1], name="C")
   
       s = te.create_schedule(C.op)
       oi, ii = s[C].split(C.op.axis[0], factor=2)
       s[C].bind(oi, te.thread_axis("threadIdx.x"))
       s[C].vectorize(ii) # BUG: misalignment
       return N, C_N, s, A, C
   
   def build_run():
       N, C_N, s, A, C = get_te_sch()
   
       tgt = tvm.target.Target(target="cuda", host="llvm")
       foo = tvm.build(s, [A, C], tgt, name="foo")
       dev = tvm.device(tgt.kind.name, 0)
   
       a_data = np.arange(0, N).astype(A.dtype)
       a = tvm.nd.array(a_data, dev)
       c = tvm.nd.array(np.zeros(C_N, dtype=C.dtype), dev)
       foo(a, c)
       expected = a_data[1:C_N+1]
       assert np.allclose(c.numpy(), expected), f"expected={expected}\nactual={c.numpy()}"
   
   build_run()
   print('passed')
   ```
   Notice that the compute of C is a vector load starting from an address that is not 4-byte aligned. This gets lowered into
   ```cpp
     attr = {"global_symbol": "foo", "tir.noalias": True}
     buffers = {C: Buffer(C_2: Pointer(float16), float16, [2], []),
                A: Buffer(A_2: Pointer(float16), float16, [3], [])}
     buffer_map = {A_1: A, C_1: C} {
     realize(C, [0:2], True {
       attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 1;
       for (i.inner: int32, 0, 2) "vectorized" {
         C[i.inner] = A[(i.inner + 1)]
       }
     })
   }
   ```
   followed by the `VectorizeLoop` pass, which vectorizes the loop into ``C_2[ramp(0, 1, 2)] = (float16x2*)A_2[ramp(1, 1, 2)]``; this is then codegen-ed into a misaligned vector load of `A_2[1:3]`.
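The effect of `VectorizeLoop` on the two index expressions can be modeled in a few lines of Python. The `Ramp` class below is a hypothetical stand-in for TIR's ramp node (lanes `base, base + stride, ...`), not TVM's API; the sketch substitutes `ramp(0, 1, 2)` for `i.inner` and checks whether each resulting contiguous vector access starts on a boundary of `lanes * sizeof(float16)` bytes, assuming the buffers themselves are aligned.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Ramp:
    """Hypothetical model of a TIR ramp: base, base + stride, ... (lanes values)."""
    base: int
    stride: int
    lanes: int

    def __add__(self, offset: int) -> "Ramp":
        # Adding a scalar shifts every lane: A[i.inner + 1] -> A[ramp(1, 1, 2)].
        return Ramp(self.base + offset, self.stride, self.lanes)

    def indices(self) -> List[int]:
        return [self.base + k * self.stride for k in range(self.lanes)]


def vector_access_is_aligned(r: Ramp, elem_bytes: int) -> bool:
    """Check that a contiguous vector access starts on a multiple of its size."""
    if r.stride != 1:
        raise ValueError("only contiguous (stride-1) ramps map to a vector load")
    return (r.base * elem_bytes) % (r.lanes * elem_bytes) == 0


FP16_BYTES = 2
i_inner = Ramp(base=0, stride=1, lanes=2)  # vectorized loop variable
store_idx = i_inner                        # C_2[ramp(0, 1, 2)]
load_idx = i_inner + 1                     # A_2[ramp(1, 1, 2)]

print(store_idx.indices(), vector_access_is_aligned(store_idx, FP16_BYTES))  # [0, 1] True
print(load_idx.indices(), vector_access_is_aligned(load_idx, FP16_BYTES))    # [1, 2] False
```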
   
   I'm not sure which pass should handle alignment. I suspect it should be the responsibility of codegen, because in [CodeGenLLVM](https://github.com/apache/tvm/blob/07243a89df3cdac6afaf29a74aba398190a76c72/src/target/llvm/codegen_llvm.cc#L1167) I do see alignment handling, but I do not see any in [CodeGenCUDA's visit of LoadNode](https://github.com/apache/tvm/blob/07243a89df3cdac6afaf29a74aba398190a76c72/src/target/source/codegen_c.cc#L701) (inherited from `CodeGenC`). I guess this [if branch](https://github.com/apache/tvm/blob/07243a89df3cdac6afaf29a74aba398190a76c72/src/target/source/codegen_c.cc#L711-L713) is what generates the vector load, so maybe a fix would be to add an alignment check to its condition? The else branch does not generate a vector load, so disabling the if branch also works fine in my case.
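One possible shape for such an alignment check is sketched below in Python, under stated assumptions. The element offset of the first lane is often symbolic (e.g. `threadIdx.x * 3 + 1`), so the sketch describes it as a modular set `coeff * k + base`, the kind of fact TVM's arithmetic analyzer can derive (`tvm.arith.Analyzer.modular_set`). The `ModularSet` class and `can_emit_vector_access` predicate here are hypothetical pure-Python stand-ins, not TVM code, and they assume the buffer pointer itself is aligned to the full vector width. If the predicate fails, codegen would take the scalar else branch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModularSet:
    """Hypothetical stand-in for {coeff * k + base | k integer}.

    coeff == 0 means the offset is the constant `base`.
    """
    coeff: int
    base: int


def can_emit_vector_access(first_lane: ModularSet, lanes: int) -> bool:
    """Hypothetical extra condition for the vector-load branch.

    Every value the element offset can take must be a multiple of `lanes`,
    so the access is aligned to lanes * sizeof(elem) whenever the buffer is.
    """
    return first_lane.coeff % lanes == 0 and first_lane.base % lanes == 0


LANES = 2  # float16x2, emitted as one uint1

# threadIdx.x * 3 + 1 (the placeholder read): not provably aligned, use scalars
print(can_emit_vector_access(ModularSet(coeff=3, base=1), LANES))  # False
# threadIdx.x * 2 (the T_strided_slice write): always even, vector access is fine
print(can_emit_vector_access(ModularSet(coeff=2, base=0), LANES))  # True
# constant offset 1, as in A_2[ramp(1, 1, 2)] from the TE example
print(can_emit_vector_access(ModularSet(coeff=0, base=1), LANES))  # False
```

The check is conservative: an offset the analyzer cannot bound (coefficient 1) simply falls back to scalar loads, which trades some bandwidth for correctness.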
   
   
   ## Environment
   - TVM: commit 07243a89df3cdac6afaf29a74aba398190a76c72
   - CUDA version: 10.0
   - System: Ubuntu 16.04
   - GCC 5.4
   - Build options: ``-DUSE_RELAY_DEBUG=ON -DUSE_CUBLAS=ON -DUSE_LLVM=ON -DUSE_CUDA=ON``
   
   

