twzhyyxwhez31057 opened a new issue, #15936:
URL: https://github.com/apache/tvm/issues/15936

   Hello, I am learning how to deploy to Adreno based on this 
document: https://tvm.apache.org/docs/how_to/deploy/adreno.html. I tried to use 
this pass to export the TIR after tuning and make some adjustments to it.
   ```
   @tvm.tir.transform.prim_func_pass(opt_level=0)
       def print_tir(f, mod, ctx):
           print(f)
           return f
   ```
    However, I found that the exported TIR has some "has no attribute" problems, 
such as T.nd_mem_alloc_with_scope and T.texture2d_load, which make it 
impossible to build the TIR directly. How can I build OpenCL code that supports 
texture directly from TIR?
   Part of the TIR I extracted from relay:
   ```
   @T.prim_func
   def tvmgen_default_fused_nn_conv2d_add_nn_relu(p0: T.Buffer((1, 3, 224, 
224), "float16"), p1: T.Buffer((16, 3, 7, 7, 4), "float16"), p2: T.Buffer((1, 
16, 1, 1, 4), "float16"), T_relu: T.Buffer((1, 16, 112, 112, 4), "float16")):
       T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": 
T.bool(True)})
       input_pack = T.allocate([200704], "float16", "global")
       input_pack_1 = T.Buffer((200704,), "float16", data=input_pack)
       with T.launch_thread("blockIdx.x", 784) as blockIdx_x:
           threadIdx_x = T.launch_thread("threadIdx.x", 64)
           for i4_s in range(4):
               p0_1 = T.Buffer((150528,), "float16", data=p0.data)
               input_pack_1[blockIdx_x * 256 + threadIdx_x * 4 + i4_s] = 
T.if_then_else(i4_s == 3, T.float16(0), p0_1[i4_s * 50176 + blockIdx_x * 64 + 
threadIdx_x])
       pad_temp_global_texture: T.handle("float16", "global.texture") = 
T.nd_mem_alloc_with_scope("global.texture", T.int64(2), 
T.tvm_stack_make_shape(229, 229))
       with T.launch_thread("blockIdx.x", 52441) as blockIdx_x:
           threadIdx_x = T.launch_thread("threadIdx.x", 1)
           T.texture2d_store(pad_temp_global_texture, blockIdx_x % 229, 
blockIdx_x // 229, T.if_then_else(687 <= blockIdx_x and blockIdx_x < 51983 and 
3 <= blockIdx_x % 229 and blockIdx_x % 229 < 227, input_pack_1[blockIdx_x // 
229 * 896 + blockIdx_x % 229 * 4 - 2700:blockIdx_x // 229 * 896 + blockIdx_x % 
229 * 4 - 2700 + 4], T.Broadcast(T.float16(0), 4)))
       blockIdx_z = T.launch_thread("blockIdx.z", 1)
       compute = T.allocate([8], "float16x4", "local")
       blockIdx_y = T.launch_thread("blockIdx.y", 4)
       blockIdx_x = T.launch_thread("blockIdx.x", 14)
       threadIdx_z = T.launch_thread("threadIdx.z", 16)
       threadIdx_y = T.launch_thread("threadIdx.y", 7)
       threadIdx_x = T.launch_thread("threadIdx.x", 4)
       compute_1 = T.Buffer((32,), "float16x4", data=compute, scope="local", 
align=8)
       compute_1[0] = T.Broadcast(T.float16(0), 4)
       compute_1[2] = T.Broadcast(T.float16(0), 4)
       compute_1[4] = T.Broadcast(T.float16(0), 4)
       compute_1[6] = T.Broadcast(T.float16(0), 4)
       compute_1[1] = T.Broadcast(T.float16(0), 4)
       compute_1[3] = T.Broadcast(T.float16(0), 4)
       compute_1[5] = T.Broadcast(T.float16(0), 4)
       compute_1[7] = T.Broadcast(T.float16(0), 4)
       for ry_outer, rx_outer in T.grid(7, 7):
           for rc in T.unroll(4):
               p1_1 = T.Buffer((9408,), "float16", data=p1.data)
               compute_1[0] = compute_1[0] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer, rc) * p1_1[threadIdx_z 
* 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 588 + rc * 196 + 
ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[2] = compute_1[2] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer + 14, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[4] = compute_1[4] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer + 28, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[6] = compute_1[6] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer + 42, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[1] = compute_1[1] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer + 8, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[3] = compute_1[3] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer + 8, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer + 14, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[5] = compute_1[5] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer + 8, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer + 28, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
               compute_1[7] = compute_1[7] + 
T.texture2d_load(pad_temp_global_texture, blockIdx_x * 16 + threadIdx_x * 2 + 
rx_outer + 8, blockIdx_y * 56 + threadIdx_y * 2 + ry_outer + 42, rc) * 
p1_1[threadIdx_z * 588 + rc * 196 + ry_outer * 28 + rx_outer * 4:threadIdx_z * 
588 + rc * 196 + ry_outer * 28 + rx_outer * 4 + 4]
       T_relu_1 = T.Buffer((802816,), "float16", data=T_relu.data)
       p2_1 = T.Buffer((64,), "float16", data=p2.data)
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4:threadIdx_z * 50176 + blockIdx_y * 12544 + 
threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 4] = T.max(compute_1[0] 
+ p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 3136:threadIdx_z * 50176 + blockIdx_y * 
12544 + threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 3136 + 4] = 
T.max(compute_1[2] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 6272:threadIdx_z * 50176 + blockIdx_y * 
12544 + threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 6272 + 4] = 
T.max(compute_1[4] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 9408:threadIdx_z * 50176 + blockIdx_y * 
12544 + threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 9408 + 4] = 
T.max(compute_1[6] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 16:threadIdx_z * 50176 + blockIdx_y * 12544 
+ threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 16 + 4] = 
T.max(compute_1[1] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 3152:threadIdx_z * 50176 + blockIdx_y * 
12544 + threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 3152 + 4] = 
T.max(compute_1[3] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 6288:threadIdx_z * 50176 + blockIdx_y * 
12544 + threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 6288 + 4] = 
T.max(compute_1[5] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
       T_relu_1[threadIdx_z * 50176 + blockIdx_y * 12544 + threadIdx_y * 448 + 
blockIdx_x * 32 + threadIdx_x * 4 + 9424:threadIdx_z * 50176 + blockIdx_y * 
12544 + threadIdx_y * 448 + blockIdx_x * 32 + threadIdx_x * 4 + 9424 + 4] = 
T.max(compute_1[7] + p2_1[threadIdx_z * 4:threadIdx_z * 4 + 4], 
T.Broadcast(T.float16(0), 4))
   
   ```
   @srkreddy1238 @echuraev 
   
   ### Expected behavior
   
   Build OpenCL code that supports texture directly from TIR
   
   ### Actual behavior
   
   module 'tvm.script.tir' has no attribute 'nd_mem_alloc_with_scope'
   
   ### Environment
   
   opencl -device=adreno
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to