wejoncy opened a new issue #7247:
URL: https://github.com/apache/tvm/issues/7247
## Description
TVM generates a float-typed index expression, e.g.
`B_local[((((float)get_local_id(1)) + ((float)(((int)get_group_id(2)) * 8))))];`,
which is illegal: an OpenCL array subscript must be an integer.
## Error information
```
<kernel>:41:227: error: array subscript is not an integer
B[((((float)((((((int)get_group_id(0)) * 131072) + (yy_inner * 16384)) + (((int)get_group_id(1)) * 4096)) + (((int)get_local_id(0)) * 512))) + (((float)get_local_id(1)) + ((float)(((int)get_group_id(2)) * 8)))))] = B_local[((((float)get_local_id(1)) + ((float)(((int)get_group_id(2)) * 8))))];
```
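For what it's worth, the float seems to come from the schedule itself (the `threadIdx.y` extent in the reproducer below) rather than from the OpenCL printer. A sketch of how to check, assuming the schedule `s` and the tensors `A`, `W`, `B` from the reproducer are in scope: print the lowered TIR before building, where a float-typed index should already be visible.
```
# A sketch, not a fix: with `s`, `A`, `W`, `B` from the reproducer below in
# scope, print the lowered TIR before codegen. If the index expression is
# already float-typed at this stage, the problem is upstream of the OpenCL
# code generator.
print(tvm.lower(s, [A, W, B], simple_mode=True))
```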
## Env
TVM, built from the `main` branch
## Minimal code to reproduce.
```
import numpy as np
import tvm
from tvm import te
# The sizes of inputs and filters
batch = 1
in_channel = 256
out_channel = 512
in_size = 32
kernel = 1
pad = 0
stride = 1
# Algorithm
A = te.placeholder((in_size, in_size, in_channel, batch), name="A")
W = te.placeholder((kernel, kernel, in_channel, out_channel), name="W")
out_size = (in_size - kernel + 2 * pad) // stride + 1
# Pad input
Apad = te.compute(
    (in_size + 2 * pad, in_size + 2 * pad, in_channel, batch),
    lambda yy, xx, cc, nn: tvm.tir.if_then_else(
        tvm.tir.all(yy >= pad, yy - pad < in_size, xx >= pad, xx - pad < in_size),
        A[yy - pad, xx - pad, cc, nn],
        tvm.tir.const(0.0, "float32"),
    ),
    name="Apad",
)
# Create reduction variables
rc = te.reduce_axis((0, in_channel), name="rc")
ry = te.reduce_axis((0, kernel), name="ry")
rx = te.reduce_axis((0, kernel), name="rx")
# Compute the convolution
B = te.compute(
    (out_size, out_size, out_channel, batch),
    lambda yy, xx, ff, nn: te.sum(
        Apad[yy * stride + ry, xx * stride + rx, rc, nn] * W[ry, rx, rc, ff],
        axis=[ry, rx, rc],
    ),
    name="B",
)
# Designate the memory hierarchy
s = te.create_schedule(B.op)
s[Apad].compute_inline() # compute Apad inline
AA = s.cache_read(Apad, "shared", [B])
WW = s.cache_read(W, "shared", [B])
AL = s.cache_read(AA, "local", [B])
WL = s.cache_read(WW, "local", [B])
BL = s.cache_write(B, "local")
# tile consts
tile = 1
num_thread = 8
block_factor = tile * num_thread
step = 8
vthread = 2
# Get the GPU thread indices
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")
thread_y = te.thread_axis((0, num_thread/4), "threadIdx.y")
thread_xz = te.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = te.thread_axis((0, vthread), "vthread", name="vy")
# Split the workloads
hi, wi, fi, ni = s[B].op.axis
bz, fi = s[B].split(fi, factor=block_factor)
bx, hi = s[B].split(hi, factor=block_factor)
by, wi = s[B].split(wi, factor=block_factor)
# Bind the iteration variables to GPU thread indices
s[B].bind(bz, block_z)
s[B].bind(by, block_y)
s[B].bind(bx, block_x)
ty, fi = s[B].split(fi, nparts=num_thread)
tx, wi = s[B].split(wi, nparts=num_thread)
s[B].reorder(bz, by, bx, ty, wi, fi)
s[B].bind(ty, thread_y)
s[B].bind(tx, thread_x)
# Schedule BL local write
s[BL].compute_at(s[B], ni)
fi, noo = s[B].split(fi, factor=4)
s[B].vectorize(noo) # vectorize memory load
yi, xi, fi, ni = s[BL].op.axis
ry, rx, rc = s[BL].op.reduce_axis
rco, rci = s[BL].split(rc, factor=step)
s[BL].reorder(rco, ry, rx, rci, fi, ni)
# Attach computation to iteration variables
s[AA].compute_at(s[BL], rx)
s[WW].compute_at(s[BL], rx)
s[AL].compute_at(s[BL], rci)
s[WL].compute_at(s[BL], rci)
yi, xi, ci, ni = s[AA].op.axis
tx, ni = s[AA].split(ni, nparts=num_thread)
_, ci = s[AA].split(ci, factor=4)
s[AA].bind(tx, thread_x)
s[AA].vectorize(ci) # vectorize memory load
# Schedule for W's shared memory load
yi, xi, ci, fi = s[WW].op.axis
ty, ci = s[WW].split(ci, nparts=num_thread)
fi, fv = s[WW].split(fi, factor=4)
s[WW].bind(ty, thread_y)
s[WW].vectorize(fv) # vectorize memory load
target="opencl"
func = tvm.build(s, [A, W, B], target)
print("------opencl code------")
if len(func.imported_modules) > 0:
    print(func.imported_modules[0].get_source())
else:
    print("source not imported")
ctx = tvm.context(target, 0)
np.random.seed(5)
a_np = np.random.uniform(size=(in_size, in_size, in_channel, batch)).astype("float32")
w_np = np.random.uniform(size=(kernel, kernel, in_channel, out_channel)).astype("float32")
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros((out_size, out_size, out_channel, batch), dtype="float32"), ctx)
func(a, w, b)
np.savetxt("filename.txt",b.asnumpy()[:,:,0,0])
evaluator = func.time_evaluator(func.entry_name, ctx, number=5)
print("Convolution: %f ms" % (evaluator(a, w, b).mean * 1e3))
```
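The trigger looks like `num_thread/4` on the `threadIdx.y` axis: under Python 3 that is true division and evaluates to `2.0`, and `te.thread_axis` appears to accept the float extent silently. If that reading is right, integer division keeps every extent integral and should sidestep the error as a workaround, although TVM should arguably still reject or cast a float extent rather than emit illegal code:
```
# Possible workaround (an assumption, not a confirmed fix): keep the
# threadIdx.y extent integral. Python 3 `/` is true division (8 / 4 == 2.0),
# while `//` keeps the value an int.
thread_y = te.thread_axis((0, num_thread // 4), "threadIdx.y")
```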