This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e95f10f  [CUDA] Initial support for dynamic shared memory (#8466)
e95f10f is described below

commit e95f10f8fdc30a02ae004dc8a7aad72538b38d08
Author: masahi <[email protected]>
AuthorDate: Thu Jul 22 14:26:49 2021 +0900

    [CUDA] Initial support for dynamic shared memory (#8466)
    
    * send dyn shmem size to runtime
    
    * add dyn shared storage scope
    
    * associate buffer var and its storage scope in split_host_device
    
    * tried NVPTX but failed with INVALID_PTX error
    
    * test stub
    
    * dynamic shmem reduce working
    
    * log2 issue fixed
    
    * nvptx working
    
    * refactor llvm shmem allocation
    
    * make linkage an argument
    
    * support rocm too
    
    * send dyn shmem param to hip runtime
    
    * remove alloc map from split_host_device.cc
    
    * remove attr::storage_scope from split_host_device
    
    * lint fix
    
    * formatting
    
    * update calling convention doc
    
    * minor update to test
    
    * remove log
    
    * remove kDynShared, dyn.shared -> shared.dyn
    
    * support backward compat
    
    * update json/binary reader/writer
    
    * thread_axis_tags -> launch_param_tags
    
    * ThreadAxisConfig -> LaunchParamConfig
    
    * remove use_dynamic_shared_memory from FunctionInfo metadata
    
    * revert change in test_tir_ir_builder.py
    
    * make sure kUseDynamicSharedMemoryTag is the last tag
    
    * remove continue
    
    * update docstring following the name change
    
    * more comment updates following the name change
    
    Co-authored-by: masa <[email protected]>
    Co-authored-by: Masahiro Masuda <masahi@[email protected]>
---
 docs/dev/codebase_walkthrough.rst                  |  4 +-
 include/tvm/tir/function.h                         | 11 +++-
 src/runtime/cuda/cuda_module.cc                    | 14 ++---
 src/runtime/file_utils.cc                          | 10 ++--
 src/runtime/meta_data.h                            |  5 +-
 src/runtime/metal/metal_module.mm                  | 12 ++--
 src/runtime/opencl/opencl_module.cc                | 14 ++---
 src/runtime/rocm/rocm_module.cc                    | 19 +++---
 src/runtime/thread_storage_scope.h                 | 33 +++++++---
 src/runtime/vulkan/vulkan_wrapped_func.cc          |  8 +--
 src/runtime/vulkan/vulkan_wrapped_func.h           |  7 +--
 src/target/build_common.h                          |  7 ++-
 src/target/llvm/codegen_amdgpu.cc                  | 70 ++++++++++------------
 src/target/llvm/codegen_llvm.cc                    | 16 +++++
 src/target/llvm/codegen_llvm.h                     |  5 ++
 src/target/llvm/codegen_nvptx.cc                   | 60 +++++++++----------
 src/target/source/codegen_cuda.cc                  | 28 ++++++---
 .../transforms/lower_device_storage_access_info.cc |  2 +-
 src/tir/transforms/split_host_device.cc            | 23 +++++++
 src/tir/transforms/storage_rewrite.cc              |  6 +-
 tests/python/unittest/test_tir_ir_builder.py       | 58 ++++++++++++++++++
 web/src/webgpu.ts                                  |  6 +-
 22 files changed, 277 insertions(+), 141 deletions(-)

diff --git a/docs/dev/codebase_walkthrough.rst 
b/docs/dev/codebase_walkthrough.rst
index 60ab5e5..efc8b32 100644
--- a/docs/dev/codebase_walkthrough.rst
+++ b/docs/dev/codebase_walkthrough.rst
@@ -183,7 +183,7 @@ The first time you invoke the compiled module with 
``fadd(a, b, c)``, ``GetFunct
      auto it = fmap_.find(name);
      const FunctionInfo& info = it->second;
      CUDAWrappedFunc f;
-     f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.thread_axis_tags);
+     f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.launch_param_tags);
      return PackFuncVoidAddr(f, info.arg_types);
    }
 
@@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be 
called, which in turn cal
          fcache_[device_id] = m_->GetFunc(device_id, func_name_);
        }
        CUstream strm = 
static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
-       ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+       ThreadWorkLoad wl = launch_param_config_.Extract(args);
        CUresult result = cuLaunchKernel(
            fcache_[device_id],
            wl.grid_dim(0),
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 25ed2f9..55f4fc6 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -240,10 +240,12 @@ namespace attr {
  *
  * Call(f,
  *      [arg1, arg2, ..., arg_n,
- *       work_size_1, work_size_2, ... work_size_m])
+ *       work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
  *
  * Here n = len(arg), m = len(work_size) = len(device_thread_axis).
  *
+ * When kDeviceUseDynSharedMemory is not set, dyn_shmem_size argument is 
omitted.
+ *
  * The list of device_thread_axis indicates how can be bind the
  * work_size arguments to the corresponding threads.
  *
@@ -252,6 +254,13 @@ namespace attr {
 constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
 
 /*!
+ * \brief Whether or not use dynamic shared memory.
+ *
+ * Type: Integer
+ */
+constexpr const char* kDeviceUseDynSharedMemory = 
"tir.device_use_dyn_shared_memory";
+
+/*!
  * \brief Whether to set noalias rule on the function arguments.
  *
  * Type: Integer
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index a877bc6..7d6879a 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -153,12 +153,12 @@ class CUDAWrappedFunc {
  public:
   // initialize the CUDA function.
   void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& 
func_name,
-            size_t num_void_args, const std::vector<std::string>& 
thread_axis_tags) {
+            size_t num_void_args, const std::vector<std::string>& 
launch_param_tags) {
     m_ = m;
     sptr_ = sptr;
     func_name_ = func_name;
     std::fill(fcache_.begin(), fcache_.end(), nullptr);
-    thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
+    launch_param_config_.Init(num_void_args, launch_param_tags);
   }
   // invoke the function with void arguments
   void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -168,10 +168,10 @@ class CUDAWrappedFunc {
       fcache_[device_id] = m_->GetFunc(device_id, func_name_);
     }
     CUstream strm = 
static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
-    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+    ThreadWorkLoad wl = launch_param_config_.Extract(args);
     CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), 
wl.grid_dim(1),
                                      wl.grid_dim(2), wl.block_dim(0), 
wl.block_dim(1),
-                                     wl.block_dim(2), 0, strm, void_args, 
nullptr);
+                                     wl.block_dim(2), wl.dyn_shmem_size, strm, 
void_args, nullptr);
     if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
       const char* msg;
       cuGetErrorName(result, &msg);
@@ -201,8 +201,8 @@ class CUDAWrappedFunc {
   // Device function cache per device.
   // mark as mutable, to enable lazy initialization
   mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
-  // thread axis configuration
-  ThreadAxisConfig thread_axis_cfg_;
+  // launch parameters configuration
+  LaunchParamConfig launch_param_config_;
 };
 
 class CUDAPrepGlobalBarrier {
@@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& 
name,
   if (it == fmap_.end()) return PackedFunc();
   const FunctionInfo& info = it->second;
   CUDAWrappedFunc f;
-  f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.thread_axis_tags);
+  f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.launch_param_tags);
   return PackFuncVoidAddr(f, info.arg_types);
 }
 
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 32dd1d8..35832e8 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
   writer->BeginObject();
   writer->WriteObjectKeyValue("name", name);
   writer->WriteObjectKeyValue("arg_types", sarg_types);
-  writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
+  writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
   writer->EndObject();
 }
 
@@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
   std::vector<std::string> sarg_types;
   helper.DeclareField("name", &name);
   helper.DeclareField("arg_types", &sarg_types);
-  helper.DeclareField("thread_axis_tags", &thread_axis_tags);
+  helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
+  helper.DeclareOptionalField("thread_axis_tags",
+                              &launch_param_tags);  // for backward 
compatibility
   helper.ReadAllFields(reader);
   arg_types.resize(sarg_types.size());
   for (size_t i = 0; i < arg_types.size(); ++i) {
@@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
 void FunctionInfo::Save(dmlc::Stream* writer) const {
   writer->Write(name);
   writer->Write(arg_types);
-  writer->Write(thread_axis_tags);
+  writer->Write(launch_param_tags);
 }
 
 bool FunctionInfo::Load(dmlc::Stream* reader) {
   if (!reader->Read(&name)) return false;
   if (!reader->Read(&arg_types)) return false;
-  if (!reader->Read(&thread_axis_tags)) return false;
+  if (!reader->Read(&launch_param_tags)) return false;
   return true;
 }
 
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index e3ec155..002012a 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -99,11 +99,14 @@ Module MetadataModuleCreate(
     const std::unordered_map<std::string, NDArray>& metadata,
     const std::unordered_map<std::string, std::vector<std::string>>& sym_vars);
 
+/*! \brief A tag to specify whether or not dynamic shared memory is used */
+constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory";
+
 /*! \brief function information needed by device */
 struct FunctionInfo {
   std::string name;
   std::vector<DLDataType> arg_types;
-  std::vector<std::string> thread_axis_tags;
+  std::vector<std::string> launch_param_tags;
 
   void Save(dmlc::JSONWriter* writer) const;
   void Load(dmlc::JSONReader* reader);
diff --git a/src/runtime/metal/metal_module.mm 
b/src/runtime/metal/metal_module.mm
index 8850188..1e81ac1 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -178,7 +178,7 @@ class MetalWrappedFunc {
   // initialize the METAL function.
   void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& 
func_name,
             size_t num_buffer_args, size_t num_pack_args,
-            const std::vector<std::string>& thread_axis_tags) {
+            const std::vector<std::string>& launch_param_tags) {
     w_ = metal::MetalWorkspace::Global();
     m_ = m;
     sptr_ = sptr;
@@ -186,7 +186,7 @@ class MetalWrappedFunc {
     num_buffer_args_ = num_buffer_args;
     num_pack_args_ = num_pack_args;
     std::fill(scache_.begin(), scache_.end(), 
(id<MTLComputePipelineState>)nil);
-    thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
+    launch_param_config_.Init(num_buffer_args + num_pack_args, 
launch_param_tags);
     metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
     int dev_id = t->device.device_id;
     scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
@@ -201,7 +201,7 @@ class MetalWrappedFunc {
       if (scache_[device_id] == nil) {
         scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
       }
-      ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+      ThreadWorkLoad wl = launch_param_config_.Extract(args);
       int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
       auto maxTotalThreadsPerThreadgroup = 
scache_[device_id].maxTotalThreadsPerThreadgroup;
       CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
@@ -242,8 +242,8 @@ class MetalWrappedFunc {
   // Device state cache per device.
   // mark as mutable, to enable lazy initialization
   mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
-  // thread axis configuration
-  ThreadAxisConfig thread_axis_cfg_;
+  // launch parameters configuration
+  LaunchParamConfig launch_param_config_;
 };
 
 PackedFunc MetalModuleNode::GetFunction(const std::string& name,
@@ -261,7 +261,7 @@ PackedFunc MetalModuleNode::GetFunction(const std::string& 
name,
     MetalWrappedFunc f;
     size_t num_buffer_args = NumBufferArgs(info.arg_types);
     f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - 
num_buffer_args,
-           info.thread_axis_tags);
+           info.launch_param_tags);
     pf = PackFuncNonBufferArg(f, info.arg_types);
   };
   return pf;
diff --git a/src/runtime/opencl/opencl_module.cc 
b/src/runtime/opencl/opencl_module.cc
index 4040d82..f6c7f62 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -40,14 +40,14 @@ class OpenCLWrappedFunc {
   // initialize the OpenCL function.
   void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, 
OpenCLModuleNode::KTRefEntry entry,
             std::string func_name, std::vector<size_t> arg_size,
-            const std::vector<std::string>& thread_axis_tags) {
+            const std::vector<std::string>& launch_param_tags) {
     w_ = m->GetGlobalWorkspace();
     m_ = m;
     sptr_ = sptr;
     entry_ = entry;
     func_name_ = func_name;
     arg_size_ = arg_size;
-    thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
+    launch_param_config_.Init(arg_size.size(), launch_param_tags);
   }
   // invoke the function with void arguments
   void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -73,8 +73,8 @@ class OpenCLWrappedFunc {
       OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg));
     }
     cl_command_queue queue = w_->GetQueue(t->device);
-    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
-    cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
+    ThreadWorkLoad wl = launch_param_config_.Extract(args);
+    cl_uint work_dim = static_cast<cl_uint>(launch_param_config_.work_dim());
     for (cl_uint i = 0; i < work_dim; ++i) {
       wl.work_size[i] *= wl.work_size[i + 3];
     }
@@ -96,8 +96,8 @@ class OpenCLWrappedFunc {
   std::string func_name_;
   // convert code for void argument
   std::vector<size_t> arg_size_;
-  // thread axis config
-  ThreadAxisConfig thread_axis_cfg_;
+  // launch parameters config
+  LaunchParamConfig launch_param_config_;
 };
 
 OpenCLModuleNode::~OpenCLModuleNode() {
@@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& 
name,
     }
   }
   // initialize the wrapped func.
-  f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, 
info.thread_axis_tags);
+  f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, 
info.launch_param_tags);
   return PackFuncVoidAddr(f, info.arg_types);
 }
 
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 567557c..487ad23 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -147,12 +147,12 @@ class ROCMWrappedFunc {
  public:
   // initialize the ROCM function.
   void Init(ROCMModuleNode* m, ObjectPtr<Object> sptr, const std::string& 
func_name,
-            size_t num_void_args, const std::vector<std::string>& 
thread_axis_tags) {
+            size_t num_void_args, const std::vector<std::string>& 
launch_param_tags) {
     m_ = m;
     sptr_ = sptr;
     func_name_ = func_name;
     std::fill(fcache_.begin(), fcache_.end(), nullptr);
-    thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
+    launch_param_config_.Init(num_void_args, launch_param_tags);
   }
   // invoke the function with void arguments
   void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t 
packed_nbytes) const {
@@ -164,13 +164,14 @@ class ROCMWrappedFunc {
 
     hipStream_t strm = 
static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
 
-    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+    ThreadWorkLoad wl = launch_param_config_.Extract(args);
     void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, 
HIP_LAUNCH_PARAM_BUFFER_SIZE,
                       &packed_nbytes, HIP_LAUNCH_PARAM_END};
     // HIP supports only extra_args.
-    ROCM_DRIVER_CALL(hipModuleLaunchKernel(
-        fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), 
wl.block_dim(0),
-        wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, 
reinterpret_cast<void**>(&config)));
+    ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), 
wl.grid_dim(1),
+                                           wl.grid_dim(2), wl.block_dim(0), 
wl.block_dim(1),
+                                           wl.block_dim(2), wl.dyn_shmem_size, 
strm, nullptr,
+                                           reinterpret_cast<void**>(&config)));
   }
 
  private:
@@ -183,8 +184,8 @@ class ROCMWrappedFunc {
   // Device function cache per device.
   // mark as mutable, to enable lazy initialization
   mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
-  // thread axis configuration
-  ThreadAxisConfig thread_axis_cfg_;
+  // launch parameters configuration
+  LaunchParamConfig launch_param_config_;
 };
 
 PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
@@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& 
name,
   if (it == fmap_.end()) return PackedFunc();
   const FunctionInfo& info = it->second;
   ROCMWrappedFunc f;
-  f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.thread_axis_tags);
+  f.Init(this, sptr_to_self, name, info.arg_types.size(), 
info.launch_param_tags);
   return PackFuncPackedArg(f, info.arg_types);
 }
 
diff --git a/src/runtime/thread_storage_scope.h 
b/src/runtime/thread_storage_scope.h
index 9d140ae..ac8260f 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -19,7 +19,7 @@
 
 /*!
  * \file thread_storage_scope.h
- * \brief Extract thread axis configuration from TVMArgs.
+ * \brief Extract launch parameters configuration from TVMArgs.
  */
 #ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
 #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
@@ -29,6 +29,8 @@
 #include <string>
 #include <vector>
 
+#include "meta_data.h"
+
 namespace tvm {
 namespace runtime {
 
@@ -182,6 +184,8 @@ struct ThreadScope {
 struct ThreadWorkLoad {
   // array, first three are thread configuration.
   size_t work_size[6];
+  // Dynamic shared memory allocation size in bytes.
+  size_t dyn_shmem_size{0};
   /*!
    * \param i The block dimension.
    * \return i-th block dim
@@ -193,17 +197,23 @@ struct ThreadWorkLoad {
    */
   inline size_t grid_dim(size_t i) const { return work_size[i]; }
 };
-/*! \brief Thread axis configuration */
-class ThreadAxisConfig {
+/*! \brief Launch parameters configuration */
+class LaunchParamConfig {
  public:
-  void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
+  void Init(size_t base, const std::vector<std::string>& launch_param_tags) {
     base_ = base;
     std::vector<bool> filled(6, false);
-    for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
-      const std::string& tag = thread_axis_tags[i];
-      ThreadScope ts = ThreadScope::Create(tag);
-      arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
-      filled[ts.rank * 3 + ts.dim_index] = true;
+    for (size_t i = 0; i < launch_param_tags.size(); ++i) {
+      const std::string& tag = launch_param_tags[i];
+      if (tag == kUseDynamicSharedMemoryTag) {
+        ICHECK_EQ(i, launch_param_tags.size() - 1)
+            << "kUseDynamicSharedMemoryTag should be the last tag in 
launch_param_tags.";
+        use_dyn_shared_memory_ = true;
+      } else {
+        ThreadScope ts = ThreadScope::Create(tag);
+        arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
+        filled[ts.rank * 3 + ts.dim_index] = true;
+      }
     }
     work_dim_ = 1;
     for (int i = 0; i < 3; ++i) {
@@ -223,6 +233,9 @@ class ThreadAxisConfig {
         w.work_size[arg_index_map_[i]] = size;
       }
     }
+    if (use_dyn_shared_memory_) {
+      w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + 
arg_index_map_.size()].v_int64);
+    }
     return w;
   }
   // return the work dim
@@ -235,6 +248,8 @@ class ThreadAxisConfig {
   size_t work_dim_;
   /*! \brief The index mapping. */
   std::vector<uint32_t> arg_index_map_;
+  /*! \brief Whether or not use dynamic shared memory. */
+  bool use_dyn_shared_memory_{false};
 };
 
 }  // namespace runtime
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc 
b/src/runtime/vulkan/vulkan_wrapped_func.cc
index 103b2aa..0712f72 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.cc
+++ b/src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -33,13 +33,13 @@ namespace vulkan {
 void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr<Object> sptr,
                              const std::string& func_name, size_t 
num_buffer_args,
                              size_t num_pack_args,
-                             const std::vector<std::string>& thread_axis_tags) 
{
+                             const std::vector<std::string>& 
launch_param_tags) {
   m_ = m;
   sptr_ = sptr;
   func_name_ = func_name;
   num_buffer_args_ = num_buffer_args;
   num_pack_args_ = num_pack_args;
-  thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
+  launch_param_config_.Init(num_buffer_args + num_pack_args, 
launch_param_tags);
 }
 
 void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
@@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* 
rv,
     scache_[device_id] = m_->GetPipeline(device_id, func_name_, 
num_pack_args_);
   }
   const auto& pipeline = scache_[device_id];
-  ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+  ThreadWorkLoad wl = launch_param_config_.Extract(args);
   std::vector<VkDescriptorBufferInfo> descriptor_buffers;
   descriptor_buffers.resize(num_buffer_args_);
   for (size_t i = 0; i < num_buffer_args_; ++i) {
@@ -197,7 +197,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& 
name,
   VulkanWrappedFunc f;
   size_t num_buffer_args = NumBufferArgs(info.arg_types);
   f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - 
num_buffer_args,
-         info.thread_axis_tags);
+         info.launch_param_tags);
   return PackFuncNonBufferArg(std::move(f), info.arg_types);
 }
 
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h 
b/src/runtime/vulkan/vulkan_wrapped_func.h
index a174f22..cd4774b 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.h
+++ b/src/runtime/vulkan/vulkan_wrapped_func.h
@@ -58,7 +58,7 @@ class VulkanWrappedFunc {
  public:
   void Init(VulkanModuleNode* m, ObjectPtr<Object> sptr, const std::string& 
func_name,
             size_t num_buffer_args, size_t num_pack_args,
-            const std::vector<std::string>& thread_axis_tags);
+            const std::vector<std::string>& launch_param_tags);
 
   void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) 
const;
 
@@ -73,11 +73,10 @@ class VulkanWrappedFunc {
   size_t num_buffer_args_;
   // number of packed arguments.
   size_t num_pack_args_;
+  // launch parameters configuration
+  LaunchParamConfig launch_param_config_;
   // Device state cache per device.
   // mark as mutable, to enable lazy initialization
-  // thread axis configuration
-  ThreadAxisConfig thread_axis_cfg_;
-
   mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> 
scache_;
 };
 
diff --git a/src/target/build_common.h b/src/target/build_common.h
index d2fe646..c66c2b5 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -53,7 +53,12 @@ inline std::unordered_map<std::string, 
runtime::FunctionInfo> ExtractFuncInfo(co
     if (auto opt = 
f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
       auto thread_axis = opt.value();
       for (size_t i = 0; i < thread_axis.size(); ++i) {
-        info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
+        info.launch_param_tags.push_back(thread_axis[i]->thread_tag);
+      }
+    }
+    if (auto opt = f->GetAttr<Integer>(tir::attr::kDeviceUseDynSharedMemory)) {
+      if (opt.value()) {
+        info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag);
       }
     }
     auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
diff --git a/src/target/llvm/codegen_amdgpu.cc 
b/src/target/llvm/codegen_amdgpu.cc
index 9aec8f4..7770e42 100644
--- a/src/target/llvm/codegen_amdgpu.cc
+++ b/src/target/llvm/codegen_amdgpu.cc
@@ -72,51 +72,45 @@ class CodeGenAMDGPU : public CodeGenLLVM {
   void VisitStmt_(const AllocateNode* op) final {
     ICHECK(!is_zero(op->condition));
     llvm::Value* buf = nullptr;
-
-    int32_t constant_size = op->constant_allocation_size();
-    ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation in GPU";
-
     StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
-    if (constant_size % 4 == 0 && info.alignment == 0) {
-      info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
-    }
-    // maximum necessary alignment in the AMD devices
-    if (info.alignment > 16) {
-      info.alignment = 16;
-    }
     auto storage_scope = 
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
-    if (storage_scope.rank == runtime::StorageRank::kLocal) {
-      // const int local_address_space = 5;
-      // TODO(tqchen): for higher version of LLVM, local address space can be 
set.
-      llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
-        return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), 
ConstInt32(constant_size));
-      });
-      if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
-#if TVM_LLVM_VERSION >= 100
-        alloca->setAlignment(llvm::Align(info.alignment));
-#else
-        alloca->setAlignment(info.alignment);
-#endif
-      }
-      buf = alloca;
+
+    if (storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == ".dyn") {
+      LOG(WARNING) << "Dynamic shared memory support for rocm is 
experimental.";
+      buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16),
+                                 llvm::GlobalValue::ExternalLinkage);
     } else {
-      ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
-          << "Can only allocate shared or local memory inside kernel";
-      // Shared memory: address space  == 3
-      const unsigned shared_address_space = 3;
-      llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), 
constant_size);
-      // Allocate shared memory in global, address_space = 3
-      llvm::GlobalVariable* global = new llvm::GlobalVariable(
-          *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, 
".shared", nullptr,
-          llvm::GlobalValue::NotThreadLocal, shared_address_space);
-      if (global->getAlignment() < static_cast<uint32_t>(info.alignment)) {
+      int32_t constant_size = op->constant_allocation_size();
+      ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation in GPU";
+
+      if (constant_size % 4 == 0 && info.alignment == 0) {
+        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
+      }
+      // maximum necessary alignment in the AMD devices
+      if (info.alignment > 16) {
+        info.alignment = 16;
+      }
+      if (storage_scope.rank == runtime::StorageRank::kLocal) {
+        // const int local_address_space = 5;
+        // TODO(tqchen): for higher version of LLVM, local address space can 
be set.
+        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
+          return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), 
ConstInt32(constant_size));
+        });
+        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
 #if TVM_LLVM_VERSION >= 100
-        global->setAlignment(llvm::Align(info.alignment));
+          alloca->setAlignment(llvm::Align(info.alignment));
 #else
-        global->setAlignment(info.alignment);
+          alloca->setAlignment(info.alignment);
 #endif
+        }
+        buf = alloca;
+      } else {
+        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
+            << "Can only allocate shared or local memory inside kernel";
+        // Shared memory: address space  == 3
+        buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
+                                   llvm::GlobalValue::PrivateLinkage);
       }
-      buf = global;
     }
 
     buf = builder_->CreatePointerCast(
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bdae93b..b83748b 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -524,6 +524,22 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* 
buf_var, const PrimExp
   *p_alignment = align_bits / 8;
 }
 
+llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t 
size,
+                                                        unsigned int 
shared_address_space,
+                                                        int alignment,
+                                                        
llvm::GlobalValue::LinkageTypes linkage) {
+  llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
+  llvm::GlobalVariable* global =
+      new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, 
"shmem", nullptr,
+                               llvm::GlobalValue::NotThreadLocal, 
shared_address_space);
+#if TVM_LLVM_VERSION >= 100
+  global->setAlignment(llvm::Align(alignment));
+#else
+  global->setAlignment(alignment);
+#endif
+  return global;
+}
+
 std::unique_ptr<CodeGenLLVM::DebugInfo> 
CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
 #if TVM_LLVM_VERSION >= 100
   auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 810e59b..52c5b98 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -292,6 +292,11 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
                        const Var& loop_var, const Stmt& body);
   // add alias information.
   void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr 
index);
+
+  llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size,
+                                             unsigned int 
shared_address_space, int alignment,
+                                             llvm::GlobalValue::LinkageTypes 
linkage);
+
   // The IRBuilder.
   using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, 
llvm::IRBuilderDefaultInserter>;
   // The current function
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index 43ea0e6..15543ed 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -48,48 +48,44 @@ class CodeGenNVPTX : public CodeGenLLVM {
   void VisitStmt_(const AllocateNode* op) final {
     ICHECK(!is_zero(op->condition));
     llvm::Value* buf = nullptr;
-
-    int32_t constant_size = op->constant_allocation_size();
-    ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation in GPU";
     StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
-    if (constant_size % 4 == 0 && info.alignment == 0) {
-      info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
-    }
     // maximum necessary alignment in the NV devices
     if (info.alignment > 16) {
       info.alignment = 16;
     }
+
     auto storage_scope = 
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
-    if (storage_scope.rank == runtime::StorageRank::kLocal) {
-      // const int local_address_space = 5;
-      // TODO(tqchen): for higher version of LLVM, local address space can be 
set.
-      llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
-        return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), 
ConstInt32(constant_size));
-      });
-      if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
-#if TVM_LLVM_VERSION >= 100
-        alloca->setAlignment(llvm::Align(info.alignment));
-#else
-        alloca->setAlignment(info.alignment);
-#endif
-      }
-      buf = alloca;
-    } else {
-      ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
-          << "Can only allocate shared or local memory inside kernel";
+    if (storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == ".dyn") {
       // Shared memory: address space  == 3
-      const unsigned shared_address_space = 3;
-      llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), 
constant_size);
-      // Allocate shared memory in global, address_space = 3
-      llvm::GlobalVariable* global = new llvm::GlobalVariable(
-          *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, 
".shared", nullptr,
-          llvm::GlobalValue::NotThreadLocal, shared_address_space);
+      buf =
+          AllocateSharedMemory(op->dtype, 0, 3, info.alignment, 
llvm::GlobalValue::ExternalLinkage);
+    } else {
+      int32_t constant_size = op->constant_allocation_size();
+      ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation in GPU";
+
+      if (constant_size % 4 == 0 && info.alignment == 0) {
+        info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
+      }
+      if (storage_scope.rank == runtime::StorageRank::kLocal) {
+        // const int local_address_space = 5;
+        // TODO(tqchen): for higher version of LLVM, local address space can 
be set.
+        llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
+          return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), 
ConstInt32(constant_size));
+        });
+        if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
 #if TVM_LLVM_VERSION >= 100
-      global->setAlignment(llvm::Align(info.alignment));
+          alloca->setAlignment(llvm::Align(info.alignment));
 #else
-      global->setAlignment(info.alignment);
+          alloca->setAlignment(info.alignment);
 #endif
-      buf = global;
+        }
+        buf = alloca;
+      } else {
+        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
+            << "Can only allocate shared or local memory inside kernel";
+        buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
+                                   llvm::GlobalValue::PrivateLinkage);
+      }
     }
 
     buf = builder_->CreatePointerCast(
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index d7dcbec..7897490 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -525,6 +525,8 @@ void CodeGenCUDA::PrintStorageScope(const std::string& 
scope, std::ostream& os)
                                 "all global arrays as input instead";
   if (scope == "shared") {
     os << "__shared__ ";
+  } else if (scope == "shared.dyn") {
+    os << "extern __shared__ ";
   }
 }
 
@@ -703,9 +705,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   std::string vid = AllocVarID(op->buffer_var.get());
 
   this->PrintIndent();
-  int32_t constant_size = op->constant_allocation_size();
-  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation for now";
   std::string scope = GetPtrStorageScope(op->buffer_var);
+  const VarNode* buffer = op->buffer_var.as<VarNode>();
   if (scope.find("wmma.") == 0) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) 
||
@@ -719,19 +720,28 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
              op->dtype == DataType::Int(32))
           << "Accumulator only support half, float and int type for now";
     }
-    const VarNode* buffer = op->buffer_var.as<VarNode>();
-    constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
     PrintWmmaScope(scope, op->dtype, buffer, stream);
   } else {
     PrintStorageScope(scope, stream);
     PrintType(op->dtype, stream);
   }
-  if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
-       op->dtype == DataType::Int(1)) &&
-      scope == "shared") {
-    constant_size = constant_size / (32 / op->dtype.bits());
+
+  if (scope == "shared.dyn") {
+    stream << ' ' << vid << "[];\n";
+  } else {
+    int32_t constant_size = op->constant_allocation_size();
+    ICHECK_GT(constant_size, 0) << "Can only handle constant size stack 
allocation for now";
+
+    if (scope.find("wmma.") == 0) {
+      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
+    }
+    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
+         op->dtype == DataType::Int(1)) &&
+        scope == "shared") {
+      constant_size = constant_size / (32 / op->dtype.bits());
+    }
+    stream << ' ' << vid << '[' << constant_size << "];\n";
   }
-  stream << ' ' << vid << '[' << constant_size << "];\n";
 
   RegisterHandleType(op->buffer_var.get(), op->dtype);
   this->PrintStmt(op->body);
diff --git a/src/tir/transforms/lower_device_storage_access_info.cc 
b/src/tir/transforms/lower_device_storage_access_info.cc
index 829b7d8..eafed83 100644
--- a/src/tir/transforms/lower_device_storage_access_info.cc
+++ b/src/tir/transforms/lower_device_storage_access_info.cc
@@ -67,7 +67,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
       StorageScope scope = 
StorageScope::Create(op->value.as<StringImmNode>()->value);
       StorageEntry e;
       e.scope = scope;
-      if (scope.tag.length() != 0) {
+      if (scope.tag.length() != 0 && scope.tag != ".dyn") {
         e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
         ICHECK(e.info.defined()) << "Cannot find memory info of " << 
scope.to_string();
       }
diff --git a/src/tir/transforms/split_host_device.cc 
b/src/tir/transforms/split_host_device.cc
index f01d987..795ae9d 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -33,6 +33,9 @@
 
 #include <unordered_map>
 
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -89,6 +92,17 @@ class VarUseDefAnalysis : public StmtExprMutator {
 
   Stmt VisitStmt_(const AllocateNode* op) final {
     this->HandleDef(op->buffer_var.get());
+    auto storage_scope = 
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
+    if (storage_scope.rank == runtime::StorageRank::kShared && 
storage_scope.tag == ".dyn") {
+      ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory 
allocation is allowed.";
+      ICHECK_GT(op->extents.size(), 0);
+      dyn_shmem_size_ = op->extents[0];
+      for (size_t i = 1; i < op->extents.size(); ++i) {
+        dyn_shmem_size_ *= op->extents[i];
+      }
+      dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes());
+      use_dyn_shmem_ = true;
+    }
     return StmtExprMutator::VisitStmt_(op);
   }
 
@@ -175,6 +189,8 @@ class VarUseDefAnalysis : public StmtExprMutator {
   Array<Var> undefined_;
   Array<IterVar> thread_axis_;
   Array<PrimExpr> thread_extent_;
+  PrimExpr dyn_shmem_size_{0};
+  bool use_dyn_shmem_{false};
   std::unordered_map<const VarNode*, int> use_count_;
   std::unordered_map<const VarNode*, int> def_count_;
 
@@ -262,6 +278,10 @@ class HostDeviceSplitter : public StmtMutator {
         WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, 
runtime::String(kernel_symbol));
     device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, 
Integer(1));
     device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, 
device_target_);
+    if (m.use_dyn_shmem_) {
+      device_func =
+          WithAttr(std::move(device_func), 
tir::attr::kDeviceUseDynSharedMemory, Integer(1));
+    }
     (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
 
     // generate calls to the device function
@@ -273,6 +293,9 @@ class HostDeviceSplitter : public StmtMutator {
     for (PrimExpr ext : m.thread_extent_) {
       call_args.push_back(ext);
     }
+    if (m.use_dyn_shmem_) {
+      call_args.push_back(m.dyn_shmem_size_);
+    }
     return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), 
call_args));
   }
 
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index 613d026..b216b8b 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -512,7 +512,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       // try to find merge, for tagged memory
       for (size_t i = 0; i < vec.size(); ++i) {
         StorageEntry* e = vec[i];
-        if (e->scope.tag.length() != 0) {
+        if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") {
           ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be 
const size";
           for (size_t j = 0; j < i; ++j) {
             if (e->scope == vec[j]->scope) {
@@ -546,7 +546,7 @@ class StoragePlanRewriter : public StmtExprMutator {
                               make_const(DataType::Int(32), 1), 
e->allocs[0]->extents);
           e->new_alloc =
               Allocate(e->alloc_var, alloc_type, {sz}, 
e->allocs[0]->condition, Evaluate(0));
-          if (e->scope.tag.length() != 0) {
+          if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
             ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
@@ -587,7 +587,7 @@ class StoragePlanRewriter : public StmtExprMutator {
           combo_size = analyzer_.Simplify(combo_size);
           e->new_alloc =
               Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), 
Evaluate(0));
-          if (e->scope.tag.length() != 0) {
+          if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
             ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
diff --git a/tests/python/unittest/test_tir_ir_builder.py 
b/tests/python/unittest/test_tir_ir_builder.py
index 355d3ab..0329134 100644
--- a/tests/python/unittest/test_tir_ir_builder.py
+++ b/tests/python/unittest/test_tir_ir_builder.py
@@ -18,6 +18,7 @@ import tvm
 from tvm import te
 import numpy as np
 import tvm.testing
+from tvm.topi.math import cast
 
 
 def test_for():
@@ -497,6 +498,62 @@ def test_while_binary_search():
     check_target("vulkan", searchsorted_ir_gpu)
 
 
[email protected]_gpu
+def test_dyn_shared():
+    n = te.size_var("n")
+    dtype = "float32"
+    A = te.placeholder((n,), name="A")
+
+    def test_device_ir(A, B):
+        n = A.shape[0]
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", n)
+
+        temp = ib.allocate(dtype, (n,), scope="shared.dyn")  # n is symbolic 
size
+
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+
+        temp[tx] = Aptr[tx]
+        depth = tvm.tir.log2(cast(n, "float32"))
+
+        with ib.for_range(0, depth) as i:
+            ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", 
tvm.runtime.convert(["shared"])))
+            d = n >> (i + 1)
+            with ib.if_scope(tx < d):
+                temp[tx] += temp[tx + d]
+
+        Bptr[0] = temp[0]
+        return ib.get()
+
+    B = te.extern(
+        (1,),
+        [A],
+        lambda ins, outs: test_device_ir(ins[0], outs[0]),
+        name="reduce",
+        dtype=dtype,
+    )
+    s = te.create_schedule(B.op)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        freduce = tvm.build(s, [A, B], target)
+        dev = tvm.device(target, 0)
+
+        for n in [512, 1024]:
+            a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+            b = tvm.nd.array(np.zeros(1, dtype=B.dtype), dev)
+            freduce(a, b)
+            tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 
1e-4)
+
+    for target in ["cuda", "nvptx"]:
+        check_target(target)
+
+
 if __name__ == "__main__":
     test_prefetch()
     test_if()
@@ -507,3 +564,4 @@ if __name__ == "__main__":
     test_while_collatz()
     test_while_mandel()
     test_while_binary_search()
+    test_dyn_shared()
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index f12837f..226797e 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -39,7 +39,7 @@ export async function detectGPUDevice(): Promise<GPUDevice | 
undefined | null> {
 interface FunctionInfo {
   name: string;
   arg_types: Array<string>;
-  thread_axis_tags: Array<string>;
+  launch_param_tags: Array<string>;
 }
 
 /**
@@ -114,8 +114,8 @@ export class WebGPUContext {
 
     const dispatchToDim: Array<number> = [];
 
-    for (let i = 0; i < finfo.thread_axis_tags.length; ++i) {
-      const tag: string = finfo.thread_axis_tags[i];
+    for (let i = 0; i < finfo.launch_param_tags.length; ++i) {
+      const tag: string = finfo.launch_param_tags[i];
       if (tag.startsWith("blockIdx.")) {
         const target: number = tag.charCodeAt(tag.length - 1) - 
("x".charCodeAt(0));
         assert(target >= 0 && target < 3);

Reply via email to