XFPlus opened a new issue, #14557:
URL: https://github.com/apache/tvm/issues/14557

   
   
   
   When I tuned resnet-50 using meta-schedule, I found that the conv2d_winograd implementations raise the error "can not found variable buf_dyn_shmem". After digging in, I believe the cause is `MergeDynamicSharedMemoryAllocations`: it does not work correctly when a single prim_func contains multiple thread-launch statements. Given a prim_func like the one below, the pass places the `buf_dyn_shmem` allocation inside the first statement but keeps referencing it from the later statements, which leads to the error.
   
   
   ### Steps to reproduce
   ```py
   import tvm
   from tvm.ir.module import IRModule
   from tvm.script import tir as T
   import numpy as np
   
   @tvm.script.ir_module
   class MyModule:
       @T.prim_func
       def main(a: T.handle, b: T.handle, c: T.handle):
           # We exchange data between functions via handles, which are similar to pointers.
           T.func_attr({"global_symbol": "main", "tir.noalias": True})
           # Create buffer from handles.
           A = T.match_buffer(a, (8,), dtype="float32")
           B = T.match_buffer(b, (8,), dtype="float32")
           C = T.match_buffer(c, (8,), dtype="float32")
   
           # Three shared.dyn buffers are defined inside block "C" below.
   
           with T.launch_thread("threadIdx.x", 8) as vi:
               # A block is an abstraction for computation.
               with T.block("B"):
                   # Compute B = A + 1 over the launched threads.
                   B[vi] = A[vi] + 1.0
   
           with T.launch_thread("threadIdx.x", 4) as tx:
               with T.block("C"):
                   dyn_0 = T.alloc_buffer((16), dtype='float32', scope='shared.dyn')
                   dyn_1 = T.alloc_buffer((16), dtype='float32', scope='shared.dyn')
                   dyn_2 = T.alloc_buffer((32), dtype='float32', scope='shared.dyn')
                   # Stage data through the dynamic shared buffers.
                   dyn_0[tx] = B[tx]
                   dyn_1[tx] = C[tx]
                   dyn_2[tx] = dyn_0[tx]
                   dyn_2[4+tx] = dyn_1[tx]
   
                   A[tx] = dyn_2[tx]
                   A[4+tx] = dyn_2[4+tx]
   
   
   
   ir_module = MyModule
   
   sch = tvm.tir.Schedule(ir_module)
   print(type(sch))
   
   mod = sch.mod
   print(mod)
   ctx = tvm.cuda(0)
   cuda_mod = tvm.build(mod, target="cuda")
   cuda_a = tvm.nd.array(np.arange(8).astype("float32"), ctx)
   cuda_b = tvm.nd.array(np.zeros((8,)).astype("float32"), ctx)
   cuda_c = tvm.nd.array(np.arange(8).astype("float32"), ctx)
   cuda_mod(cuda_a, cuda_b, cuda_c)
   print(cuda_a)
   print(cuda_b)
   print(cuda_c)
   ```
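   
   One way to see where the allocation ends up during lowering is to print the module right after the two passes involved. The snippet below is only a sketch: `PrintAfterPass` is a hypothetical name, it assumes the pass-instrument API (`tvm.instrument.pass_instrument`) and the registered pass names `tir.MergeDynamicSharedMemoryAllocations` / `tir.SplitHostDevice`, and the build itself still fails with the error reported in this issue.
   
   ```py
   @tvm.instrument.pass_instrument
   class PrintAfterPass:
       def run_after_pass(self, mod, info):
           # Dump the IRModule right after the two passes of interest.
           if info.name in (
               "tir.MergeDynamicSharedMemoryAllocations",
               "tir.SplitHostDevice",
           ):
               print(f"===== after {info.name} =====")
               print(mod)
   
   try:
       with tvm.transform.PassContext(opt_level=3, instruments=[PrintAfterPass()]):
           tvm.build(mod, target="cuda")
   except Exception as err:  # the build currently fails as described in this issue
       print(err)
   ```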
   
   
   ### Expected behavior
   `MergeDynamicSharedMemoryAllocations` should process each thread-launch statement independently (which is what my workaround below does), or perhaps `SplitHostDevice` could be extended to cover this case.
   
   Expected generated IRModule:
   ```py
   # from tvm.script import ir as I
   # from tvm.script import tir as T
   
   @I.ir_module
   class Module:
       I.module_attrs({"runtime": None})
       @T.prim_func
       def main_kernel0(B: T.handle("float32", "global"), A: T.handle("float32", "global")):
           T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel0", 
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm", 
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, 
"tag": "", "thread_warp_size": 32}), "tir.
   is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x"], 
"tir.noalias": 1})
           threadIdx_x = T.launch_thread("threadIdx.x", 8)
           B_1 = T.Buffer((8,), data=B)
           A_1 = T.Buffer((8,), data=A)
           B_1[threadIdx_x] = A_1[threadIdx_x] + T.float32(1)
   
       @T.prim_func
       def main_kernel1(B: T.handle("float32", "global"), C: T.handle("float32", "global"), A: T.handle("float32", "global")):
           T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel1", 
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm", 
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, 
"tag": "", "thread_warp_size": 32}), "tir.
   is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x", 
"tir.use_dyn_shared_memory"], "tir.noalias": 1})
           threadIdx_x = T.launch_thread("threadIdx.x", 4)
           buf_dyn_shmem = T.allocate([48], "uint8", "shared.dyn")
           dyn_0 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
           B_1 = T.Buffer((8,), data=B)
           dyn_0[threadIdx_x + 4] = B_1[threadIdx_x]
           dyn_1 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
           C_1 = T.Buffer((8,), data=C)
           dyn_1[threadIdx_x] = C_1[threadIdx_x]
           dyn_2 = T.Buffer((8,), data=buf_dyn_shmem, scope="shared.dyn")
           dyn_2[threadIdx_x + 4] = dyn_0[threadIdx_x + 4]
           dyn_2[threadIdx_x + 4 + 4] = dyn_1[threadIdx_x]
           A_1 = T.Buffer((8,), data=A)
           T.tvm_storage_sync("shared.dyn")
           A_1[threadIdx_x] = dyn_2[threadIdx_x + 4]
           A_1[threadIdx_x + 4] = dyn_2[threadIdx_x + 4 + 4]
   
       @T.prim_func
       def main(args: T.handle, arg_type_ids: T.handle("int32"), num_args: T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"), resource_handle: T.handle) -> T.int32:
           T.func_attr({"calling_conv": 1, "global_symbol": "main", "target": 
None, "tir.is_entry_func": T.bool(True), "tir.noalias": T.bool(True)})
           ......
           T.call_packed("__tvm_set_device", 2, dev_id)
           T.attr(0, "compute_scope", "main_compute_")
           T.call_packed("main_kernel0", B, A, 8)
           T.call_packed("main_kernel1", B, C, A, 4, 48)
   ```
   
   ### Actual behavior
   
   Actual generated IRModule. Note that `buf_dyn_shmem` is allocated in `main_kernel0`, which never touches dynamic shared memory, while `main_kernel1` receives it as an undefined handle parameter; the host side then passes that dangling handle instead of the dynamic shared-memory size, which is where the "can not found variable buf_dyn_shmem" error comes from:
   ```py
   # from tvm.script import ir as I
   # from tvm.script import tir as T
   
   @I.ir_module
   class Module:
       I.module_attrs({"runtime": None})
       @T.prim_func
       def main_kernel0(B: T.handle("float32", "global"), A: T.handle("float32", "global")):
           T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel0", 
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm", 
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, 
"tag": "", "thread_warp_size": 32}), "tir.
   is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x", 
"tir.use_dyn_shared_memory"], "tir.noalias": 1})
           threadIdx_x = T.launch_thread("threadIdx.x", 8)
           buf_dyn_shmem = T.allocate([48], "uint8", "shared.dyn")
           B_1 = T.Buffer((8,), data=B)
           A_1 = T.Buffer((8,), data=A)
           B_1[threadIdx_x] = A_1[threadIdx_x] + T.float32(1)
   
       @T.prim_func
       def main_kernel1(buf_dyn_shmem: T.handle("uint8", "shared.dyn"), B: T.handle("float32", "global"), C: T.handle("float32", "global"), A: T.handle("float32", "global")):
           T.func_attr({"calling_conv": 2, "global_symbol": "main_kernel1", 
"target": T.target({"arch": "sm_80", "host": {"keys": ["cpu"], "kind": "llvm", 
"tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, 
"tag": "", "thread_warp_size": 32}), "tir.
   is_global_func": 1, "tir.kernel_launch_params": ["threadIdx.x"], 
"tir.noalias": 1})
           threadIdx_x = T.launch_thread("threadIdx.x", 4)
           dyn_0 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
           B_1 = T.Buffer((8,), data=B)
           dyn_0[threadIdx_x + 4] = B_1[threadIdx_x]
           dyn_1 = T.Buffer((4,), data=buf_dyn_shmem, scope="shared.dyn")
           C_1 = T.Buffer((8,), data=C)
           dyn_1[threadIdx_x] = C_1[threadIdx_x]
           dyn_2 = T.Buffer((8,), data=buf_dyn_shmem, scope="shared.dyn")
           dyn_2[threadIdx_x + 4] = dyn_0[threadIdx_x + 4]
           dyn_2[threadIdx_x + 4 + 4] = dyn_1[threadIdx_x]
           A_1 = T.Buffer((8,), data=A)
           T.tvm_storage_sync("shared.dyn")
           A_1[threadIdx_x] = dyn_2[threadIdx_x + 4]
           A_1[threadIdx_x + 4] = dyn_2[threadIdx_x + 4 + 4]
   
       @T.prim_func
       def main(args: T.handle, arg_type_ids: T.handle("int32"), num_args: T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"), resource_handle: T.handle) -> T.int32:
           T.func_attr({"calling_conv": 1, "global_symbol": "main", "target": 
None, "tir.is_entry_func": T.bool(True), "tir.noalias": T.bool(True)})
           ......
           T.call_packed("__tvm_set_device", 2, dev_id)
           T.attr(0, "compute_scope", "main_compute_")
           T.call_packed("main_kernel0", B, A, 8, 48)
           buf_dyn_shmem = T.handle("uint8", "shared.dyn")
           T.call_packed("main_kernel1", buf_dyn_shmem, B, C, A, 4)
   ```
   
   ### Environment
   
   I'm using TVM v0.13.dev0 with commit: 4e07a8ed6687a08b6b27db21af019a5a179b9ee1 on a linux-x86_64 machine.
   
   ### Workaround
   Here is my workaround: it runs the merge separately for each `thread_extent` statement, so every generated kernel keeps its own `buf_dyn_shmem` allocation:
   ```diff
   diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
   index 02cfad3fc..85594fabe 100644
   --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
   +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
   
   @@ -593,6 +600,26 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
      support::Arena arena_;
    };
   
   +class DynamicSharedMemoryRewriterWrapper : public StmtExprMutator {
   + public:
   +  explicit DynamicSharedMemoryRewriterWrapper() {}
   +
   + private:
   +  Stmt VisitStmt_(const AttrStmtNode* op) final {
   +    if (op->attr_key == attr::thread_extent) {
   +      auto stmt = Downcast<Stmt>(StmtMutator::VisitStmt_(op));
   +      AllocateCollector collector;
   +      collector(stmt);
   +      if (collector.dyn_shmem_allocs_.size() > 1) {
   +        DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
   +        rewriter.PlanReuse(stmt);
   +        return rewriter(std::move(stmt));
   +      }
   +      return stmt;
   +    }
   +    return StmtExprMutator::VisitStmt_(op);
   +  }
   +};
   +
    Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
      AllocateCollector collector;
      collector(stmt);
   @@ -609,7 +636,8 @@ namespace transform {
    Pass MergeDynamicSharedMemoryAllocations() {
      auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
        auto* n = f.CopyOnWrite();
   -    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
   +    // n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
   +    n->body = DynamicSharedMemoryRewriterWrapper()(std::move(n->body));
        return f;
      };
      return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
   ```
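   
   Assuming TVM is rebuilt with the patch above, the repro script should compile and launch both kernels; a quick sanity check, reusing the module and arrays from the script above:
   
   ```py
   cuda_mod = tvm.build(mod, target="cuda")
   cuda_mod(cuda_a, cuda_b, cuda_c)
   # main_kernel0 computes B = A + 1 from the original contents of A.
   print(cuda_b.numpy())
   ```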
   
   ### Triage
   
   
   * tir::transform
   

