XFPlus opened a new issue, #14557:
URL: https://github.com/apache/tvm/issues/14557
When I tuned resnet-50 using meta-schedule, I found that the conv2d_winograd
implementations raise the error "can not found variable buf_dyn_shmem". After
digging deeper, I believe it is caused by `MergeDynamicSharedMemoryAllocations`,
which does not work correctly when it encounters multiple blocks in one
prim_func. When we pass in a multi-block prim_func like the code shown below,
this pass generates the buf_dyn_shmem allocation in the first statement and
references it in the later statements, which leads to an error.
### Steps to reproduce
```py
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np
@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle):
# We exchange data between function by handles, which are similar to
pointer.
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# Create buffer from handles.
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
C = T.match_buffer(c, (8,), dtype="float32")
# We define two buffers used in block "C"
with T.launch_thread("threadIdx.x", 8) as vi:
# A block is an abstraction for computation.
with T.block("B"):
# Define a spatial block iterator and bind it to value i.
B[vi] = A[vi] + 1.0
with T.launch_thread("threadIdx.x", 4) as tx:
with T.block("C"):
dyn_0 = T.alloc_buffer((16), dtype='float32',
scope='shared.dyn')
dyn_1 = T.alloc_buffer((16), dtype='float32',
scope='shared.dyn')
dyn_2 = T.alloc_buffer((32), dtype='float32',
scope='shared.dyn')
# Define another block to use buf_dyn actually.
dyn_0[tx] = B[tx]
dyn_1[tx] = C[tx]
dyn_2[tx] = dyn_0[tx]
dyn_2[4+tx] = dyn_1[tx]
A[tx] = dyn_2[tx]
A[4+tx] = dyn_2[4+tx]
ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)
print(type(sch))
mod = sch.mod
print(mod)
ctx = tvm.cuda(0)
cuda_mod = tvm.build(mod, target="cuda")
cuda_a = tvm.nd.array(np.arange(8).astype("float32"), ctx)
cuda_b = tvm.nd.array(np.zeros((8,)).astype("float32"), ctx)
cuda_c = tvm.nd.array(np.arange(8).astype("float32"), ctx)
cuda_mod(cuda_a, cuda_b, cuda_c)
print(cuda_a)
print(cuda_b)
print(cuda_c)
```
### Expected behavior
`MergeDynamicSharedMemoryAllocations` should process each statement
independently (as in my workaround below), or perhaps `SplitHostDevice` could
do more to handle this case.
generated irmodule:
```code
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
I.module_attrs({"runtime": None})
@T.prim_func
def main_kernel0(B: T.handle("float32", "global"), A:
T.handle("float32", "global")):
T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel0",
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm",
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024,
"tag": "", "thread_warp_size": 32}), "tir.
is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x"],
"tir.noalias": 1})
threadIdx_x = T.launch_thread("threadIdx.x", 8)
B_1 = T.Buffer((8,), data=B)
A_1 = T.Buffer((8,), data=A)
B_1[threadIdx_x] = A_1[threadIdx_x] + T.float32(1)
@T.prim_func
def main_kernel1(B: T.handle("float32", "global"), C:
T.handle("float32", "global"), A: T.handle("float32", "global")):
T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel1",
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm",
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024,
"tag": "", "thread_warp_size": 32}), "tir.
is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x",
"tir.use_dyn_shared_memory"], "tir.noalias": 1})
threadIdx_x = T.launch_thread("threadIdx.x", 4)
buf_dyn_shmem = T.allocate([48], "uint8", "shared.dyn")
dyn_0 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
B_1 = T.Buffer((8,), data=B)
dyn_0[threadIdx_x + 4] = B_1[threadIdx_x]
dyn_1 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
C_1 = T.Buffer((8,), data=C)
dyn_1[threadIdx_x] = C_1[threadIdx_x]
dyn_2 = T.Buffer((8,), data=buf_dyn_shmem, scope="shared.dyn")
dyn_2[threadIdx_x + 4] = dyn_0[threadIdx_x + 4]
dyn_2[threadIdx_x + 4 + 4] = dyn_1[threadIdx_x]
A_1 = T.Buffer((8,), data=A)
T.tvm_storage_sync("shared.dyn")
A_1[threadIdx_x] = dyn_2[threadIdx_x + 4]
A_1[threadIdx_x + 4] = dyn_2[threadIdx_x + 4 + 4]
@T.prim_func
def main(args: T.handle, arg_type_ids: T.handle("int32"), num_args:
T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"),
resource_handle: T.handle) -> T.int32:
T.func_attr({"calling_conv": 1, "global_symbol": "main", "target":
None, "tir.is_entry_func": T.bool(True), "tir.noalias": T.bool(True)})
......
T.call_packed("__tvm_set_device", 2, dev_id)
T.attr(0, "compute_scope", "main_compute_")
T.call_packed("main_kernel0", B, A, 8)
T.call_packed("main_kernel1", B, C, A, 4, 48)
```
### Actual behavior
generated irmodule:
```code
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
I.module_attrs({"runtime": None})
@T.prim_func
def main_kernel0(B: T.handle("float32", "global"), A:
T.handle("float32", "global")):
T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel0",
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm",
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024,
"tag": "", "thread_warp_size": 32}), "tir.
is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x",
"tir.use_dyn_shared_memory"], "tir.noalias": 1})
threadIdx_x = T.launch_thread("threadIdx.x", 8)
buf_dyn_shmem = T.allocate([48], "uint8", "shared.dyn")
B_1 = T.Buffer((8,), data=B)
A_1 = T.Buffer((8,), data=A)
B_1[threadIdx_x] = A_1[threadIdx_x] + T.float32(1)
@T.prim_func
def main_kernel1(buf_dyn_shmem: T.handle("uint8", "shared.dyn"), B:
T.handle("float32", "global"), C: T.handle("float32", "global"), A:
T.handle("float32", "global")):
T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel1",
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm",
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024,
"tag": "", "thread_warp_size": 32}), "tir.
is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x"],
"tir.noalias": 1})
threadIdx_x = T.launch_thread("threadIdx.x", 4)
dyn_0 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
B_1 = T.Buffer((8,), data=B)
dyn_0[threadIdx_x + 4] = B_1[threadIdx_x]
dyn_1 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
C_1 = T.Buffer((8,), data=C)
dyn_1[threadIdx_x] = C_1[threadIdx_x]
dyn_2 = T.Buffer((8,), data=buf_dyn_shmem, scope="shared.dyn")
dyn_2[threadIdx_x + 4] = dyn_0[threadIdx_x + 4]
dyn_2[threadIdx_x + 4 + 4] = dyn_1[threadIdx_x]
A_1 = T.Buffer((8,), data=A)
T.tvm_storage_sync("shared.dyn")
A_1[threadIdx_x] = dyn_2[threadIdx_x + 4]
A_1[threadIdx_x + 4] = dyn_2[threadIdx_x + 4 + 4]
@T.prim_func
def main(args: T.handle, arg_type_ids: T.handle("int32"), num_args:
T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"),
resource_handle: T.handle) -> T.int32:
T.func_attr({"calling_conv": 1, "global_symbol": "main", "target":
None, "tir.is_entry_func": T.bool(True), "tir.noalias": T.bool(True)})
......
T.call_packed("__tvm_set_device", 2, dev_id)
T.attr(0, "compute_scope", "main_compute_")
T.call_packed("main_kernel0", B, A, 8, 48)
buf_dyn_shmem = T.handle("uint8", "shared.dyn")
T.call_packed("main_kernel1", buf_dyn_shmem, B, C, A, 4)
```
### Environment
I'm using TVM v0.13.dev0 with commit:
4e07a8ed6687a08b6b27db21af019a5a179b9ee1 on a linux-x86_64 machine.
### Workaround
Here is my workaround:
```git
diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
index 02cfad3fc..85594fabe 100644
--- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
@@ -593,6 +600,26 @@ class DynamicSharedMemoryRewriter : public
StmtExprMutator {
support::Arena arena_;
};
+class DynamicSharedMemoryRewriterWrapper : public StmtExprMutator {
+ public:
+ explicit DynamicSharedMemoryRewriterWrapper() {}
+
+ private:
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::thread_extent) {
+ auto stmt = Downcast<Stmt>(StmtMutator::VisitStmt_(op));
+ AllocateCollector collector;
+ collector(stmt);
+ if (collector.dyn_shmem_allocs_.size() > 1) {
+ DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
+ rewriter.PlanReuse(stmt);
+ return rewriter(std::move(stmt));
+ }
+ }
+};
+
Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
AllocateCollector collector;
collector(stmt);
@@ -609,7 +636,8 @@ namespace transform {
Pass MergeDynamicSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
- n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
+ // n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
+ n->body = DynamicSharedMemoryRewriterWrapper()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0,
"tir.MergeDynamicSharedMemoryAllocations", {});
```
### Triage
* tir::transform
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]