roastduck opened a new issue #5366: [TIR] lower_warp_memory cannot handle >1 warp buffers

URL: https://github.com/apache/incubator-tvm/issues/5366

The `lower_warp_memory` pass cannot handle more than one warp buffer: every buffer after the first is not correctly transformed to warp shuffles. To reproduce:

```python
import tvm
import topi
import numpy as np
from tvm import te

dtype = "float32"
target = "cuda"
m = 32

A = te.placeholder((m,), name='A', dtype=dtype)
B = te.placeholder((m,), name='B', dtype=dtype)
C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name='C')

cuda_target = tvm.target.create("cuda")
assert m <= cuda_target.thread_warp_size
with cuda_target:
    s = te.create_schedule(C.op)
    tx = te.thread_axis("threadIdx.x")
    bx = te.thread_axis("blockIdx.x")

    # Two warp-scope cache reads: AA is lowered correctly, BB is not.
    AA = s.cache_read(A, "warp", [C])
    BB = s.cache_read(B, "warp", [C])

    xo, xi = s[C].split(C.op.axis[0], nparts=1)
    s[C].bind(xi, tx)
    s[C].bind(xo, bx)
    s[AA].compute_at(s[C], xo)
    s[BB].compute_at(s[C], xo)

    xo, xi = s[AA].split(s[AA].op.axis[0], nparts=1)
    s[AA].bind(xo, bx)
    s[AA].bind(xi, tx)
    xo, xi = s[BB].split(s[BB].op.axis[0], nparts=1)
    s[BB].bind(xo, bx)
    s[BB].bind(xi, tx)

    print(tvm.lower(s, [A, B, C], target, simple_mode=True))
    compute = tvm.build(s, [A, B, C], target, name="run")
    print(compute.imported_modules[0].get_source())
```

I think the problem is that `WarpMemoryRewriter::VisitStmt_(const AllocateNode*)` in `lower_warp_memory.cc` does not continue the recursion after rewriting the first buffer. I will fix it.
