This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e95f10f [CUDA] Initial support for dynamic shared memory (#8466)
e95f10f is described below
commit e95f10f8fdc30a02ae004dc8a7aad72538b38d08
Author: masahi <[email protected]>
AuthorDate: Thu Jul 22 14:26:49 2021 +0900
[CUDA] Initial support for dynamic shared memory (#8466)
* send dyn shmem size to runtime
* add dyn shared storage scope
* associate buffer var and its storage scope in split_host_device
* tried NVPTX but failed with INVALID_PTX error
* test stub
* dynamic shmem reduce working
* log2 issue fixed
* nvptx working
* refactor llvm shmem allocation
* make linkage argument
* support rocm too
* send dyn shmem param to hip runtime
* remove alloc map from split_host_device.cc
* remove attr::storage_scope from split_host_device
* lint fix
* formatting
* update calling convention doc
* minor update to test
* remove log
* remove kDynShared, dyn.shared -> shared.dyn
* support backward compat
* update json/binary reader/writer
* thread_axis_tags -> launch_param_tags
* ThreadAxisConfig -> LaunchParamConfig
* remove use_dynamic_shared_memory from FunctionInfo meta data
* revert change in test_tir_ir_builder.py
* make sure kUseDynamicSharedMemoryTag is the last tag
* remove continue
* update doc string following name change
* more comment update following name change
Co-authored-by: masa <[email protected]>
Co-authored-by: Masahiro Masuda <[email protected]>
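
For orientation, a minimal sketch of what the new "shared.dyn" scope looks like from the IR builder, modeled on the test added in this commit (the kernel below is illustrative, not part of the change itself):

    import tvm
    from tvm import te

    def device_ir(A, B):
        # Allocate a shared buffer with a symbolic size; the "shared.dyn"
        # scope turns this into a dynamic shared memory allocation whose
        # byte size is computed by split_host_device and passed to the
        # runtime at kernel launch.
        n = A.shape[0]
        ib = tvm.tir.ir_builder.create()
        tx = te.thread_axis("threadIdx.x")
        ib.scope_attr(tx, "thread_extent", n)
        temp = ib.allocate("float32", (n,), scope="shared.dyn")
        Aptr, Bptr = ib.buffer_ptr(A), ib.buffer_ptr(B)
        temp[tx] = Aptr[tx]
        Bptr[tx] = temp[tx]
        return ib.get()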
---
docs/dev/codebase_walkthrough.rst | 4 +-
include/tvm/tir/function.h | 11 +++-
src/runtime/cuda/cuda_module.cc | 14 ++---
src/runtime/file_utils.cc | 10 ++--
src/runtime/meta_data.h | 5 +-
src/runtime/metal/metal_module.mm | 12 ++--
src/runtime/opencl/opencl_module.cc | 14 ++---
src/runtime/rocm/rocm_module.cc | 19 +++---
src/runtime/thread_storage_scope.h | 33 +++++++---
src/runtime/vulkan/vulkan_wrapped_func.cc | 8 +--
src/runtime/vulkan/vulkan_wrapped_func.h | 7 +--
src/target/build_common.h | 7 ++-
src/target/llvm/codegen_amdgpu.cc | 70 ++++++++++------------
src/target/llvm/codegen_llvm.cc | 16 +++++
src/target/llvm/codegen_llvm.h | 5 ++
src/target/llvm/codegen_nvptx.cc | 60 +++++++++----------
src/target/source/codegen_cuda.cc | 28 ++++++---
.../transforms/lower_device_storage_access_info.cc | 2 +-
src/tir/transforms/split_host_device.cc | 23 +++++++
src/tir/transforms/storage_rewrite.cc | 6 +-
tests/python/unittest/test_tir_ir_builder.py | 58 ++++++++++++++++++
web/src/webgpu.ts | 6 +-
22 files changed, 277 insertions(+), 141 deletions(-)
diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst
index 60ab5e5..efc8b32 100644
--- a/docs/dev/codebase_walkthrough.rst
+++ b/docs/dev/codebase_walkthrough.rst
@@ -183,7 +183,7 @@ The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunct
auto it = fmap_.find(name);
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
- f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
+ f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}
@@ -204,7 +204,7 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(
fcache_[device_id],
wl.grid_dim(0),
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 25ed2f9..55f4fc6 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -240,10 +240,12 @@ namespace attr {
*
* Call(f,
* [arg1, arg2, ..., arg_n,
- * work_size_1, work_size_2, ... work_size_m])
+ * work_size_1, work_size_2, ... work_size_m, dyn_shmem_size])
*
* Here n = len(arg), m = len(work_size) = len(device_thread_axis).
*
+ * When kDeviceUseDynSharedMemory is not set, the dyn_shmem_size argument is omitted.
+ *
 * The list of device_thread_axis indicates how the work_size
 * arguments can be bound to the corresponding threads.
*
@@ -252,6 +254,13 @@ namespace attr {
constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
/*!
+ * \brief Whether or not to use dynamic shared memory.
+ *
+ * Type: Integer
+ */
+constexpr const char* kDeviceUseDynSharedMemory = "tir.device_use_dyn_shared_memory";
+
+/*!
* \brief Whether to set noalias rule on the function arguments.
*
* Type: Integer
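
Put differently, the packed device call now carries one extra trailing argument when the attribute is set. A rough host-side model (hypothetical helper, not a TVM API):

    def call_device_func(f, args, work_sizes, dyn_shmem_size=None):
        # f(arg_1, ..., arg_n, work_size_1, ..., work_size_m[, dyn_shmem_size])
        call_args = list(args) + list(work_sizes)
        if dyn_shmem_size is not None:  # kDeviceUseDynSharedMemory was set
            call_args.append(dyn_shmem_size)
        return f(*call_args)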
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index a877bc6..7d6879a 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -153,12 +153,12 @@ class CUDAWrappedFunc {
public:
// initialize the CUDA function.
void Init(CUDAModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
- size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
+ size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
- thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
+ launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -168,10 +168,10 @@ class CUDAWrappedFunc {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
- wl.block_dim(2), 0, strm, void_args, nullptr);
+ wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) {
const char* msg;
cuGetErrorName(result, &msg);
@@ -201,8 +201,8 @@ class CUDAWrappedFunc {
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
- // thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ // launch parameters configuration
+ LaunchParamConfig launch_param_config_;
};
class CUDAPrepGlobalBarrier {
@@ -241,7 +241,7 @@ PackedFunc CUDAModuleNode::GetFunction(const std::string& name,
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
- f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
+ f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}
diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc
index 32dd1d8..35832e8 100644
--- a/src/runtime/file_utils.cc
+++ b/src/runtime/file_utils.cc
@@ -43,7 +43,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
- writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
+ writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
writer->EndObject();
}
@@ -52,7 +52,9 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
- helper.DeclareField("thread_axis_tags", &thread_axis_tags);
+ helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
+ helper.DeclareOptionalField("thread_axis_tags",
+ &launch_param_tags); // for backward compatibility
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
@@ -63,13 +65,13 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
- writer->Write(thread_axis_tags);
+ writer->Write(launch_param_tags);
}
bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
- if (!reader->Read(&thread_axis_tags)) return false;
+ if (!reader->Read(&launch_param_tags)) return false;
return true;
}
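
The effect on serialized metadata, sketched with hypothetical JSON: either key now deserializes into launch_param_tags, so modules exported before the rename still load.

    import json

    old = '{"name": "kernel", "arg_types": ["float32"], "thread_axis_tags": ["threadIdx.x"]}'
    new = '{"name": "kernel", "arg_types": ["float32"], "launch_param_tags": ["threadIdx.x"]}'

    def load_tags(serialized):
        # Mirrors the DeclareOptionalField pair above: prefer the new key,
        # fall back to the legacy one.
        d = json.loads(serialized)
        return d.get("launch_param_tags", d.get("thread_axis_tags", []))

    assert load_tags(old) == load_tags(new) == ["threadIdx.x"]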
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index e3ec155..002012a 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -99,11 +99,14 @@ Module MetadataModuleCreate(
const std::unordered_map<std::string, NDArray>& metadata,
const std::unordered_map<std::string, std::vector<std::string>>& sym_vars);
+/*! \brief A tag to specify whether or not dynamic shared memory is used */
+constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory";
+
/*! \brief function information needed by device */
struct FunctionInfo {
std::string name;
std::vector<DLDataType> arg_types;
- std::vector<std::string> thread_axis_tags;
+ std::vector<std::string> launch_param_tags;
void Save(dmlc::JSONWriter* writer) const;
void Load(dmlc::JSONReader* reader);
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 8850188..1e81ac1 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -178,7 +178,7 @@ class MetalWrappedFunc {
// initialize the METAL function.
void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
- const std::vector<std::string>& thread_axis_tags) {
+ const std::vector<std::string>& launch_param_tags) {
w_ = metal::MetalWorkspace::Global();
m_ = m;
sptr_ = sptr;
@@ -186,7 +186,7 @@ class MetalWrappedFunc {
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
std::fill(scache_.begin(), scache_.end(), (id<MTLComputePipelineState>)nil);
- thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
+ launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags);
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int dev_id = t->device.device_id;
scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
@@ -201,7 +201,7 @@ class MetalWrappedFunc {
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
@@ -242,8 +242,8 @@ class MetalWrappedFunc {
// Device state cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<id<MTLComputePipelineState>, kMetalMaxNumDevice> scache_;
- // thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ // launch parameters configuration
+ LaunchParamConfig launch_param_config_;
};
PackedFunc MetalModuleNode::GetFunction(const std::string& name,
@@ -261,7 +261,7 @@ PackedFunc MetalModuleNode::GetFunction(const std::string& name,
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
- info.thread_axis_tags);
+ info.launch_param_tags);
pf = PackFuncNonBufferArg(f, info.arg_types);
};
return pf;
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc
index 4040d82..f6c7f62 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -40,14 +40,14 @@ class OpenCLWrappedFunc {
// initialize the OpenCL function.
void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::vector<size_t> arg_size,
- const std::vector<std::string>& thread_axis_tags) {
+ const std::vector<std::string>& launch_param_tags) {
w_ = m->GetGlobalWorkspace();
m_ = m;
sptr_ = sptr;
entry_ = entry;
func_name_ = func_name;
arg_size_ = arg_size;
- thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
+ launch_param_config_.Init(arg_size.size(), launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
@@ -73,8 +73,8 @@ class OpenCLWrappedFunc {
OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], arg));
}
cl_command_queue queue = w_->GetQueue(t->device);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
- cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
+ cl_uint work_dim = static_cast<cl_uint>(launch_param_config_.work_dim());
for (cl_uint i = 0; i < work_dim; ++i) {
wl.work_size[i] *= wl.work_size[i + 3];
}
@@ -96,8 +96,8 @@ class OpenCLWrappedFunc {
std::string func_name_;
// convert code for void argument
std::vector<size_t> arg_size_;
- // thread axis config
- ThreadAxisConfig thread_axis_cfg_;
+ // launch parameters config
+ LaunchParamConfig launch_param_config_;
};
OpenCLModuleNode::~OpenCLModuleNode() {
@@ -148,7 +148,7 @@ PackedFunc OpenCLModuleNode::GetFunction(const std::string& name,
}
}
// initialize the wrapped func.
- f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags);
+ f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.launch_param_tags);
return PackFuncVoidAddr(f, info.arg_types);
}
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 567557c..487ad23 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -147,12 +147,12 @@ class ROCMWrappedFunc {
public:
// initialize the ROCM function.
void Init(ROCMModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
- size_t num_void_args, const std::vector<std::string>& thread_axis_tags) {
+ size_t num_void_args, const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
- thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
+ launch_param_config_.Init(num_void_args, launch_param_tags);
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const {
@@ -164,13 +164,14 @@ class ROCMWrappedFunc {
hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
&packed_nbytes, HIP_LAUNCH_PARAM_END};
// HIP supports only extra_args.
- ROCM_DRIVER_CALL(hipModuleLaunchKernel(
- fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0),
- wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast<void**>(&config)));
+ ROCM_DRIVER_CALL(hipModuleLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
+ wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
+ wl.block_dim(2), wl.dyn_shmem_size, strm, nullptr,
+ reinterpret_cast<void**>(&config)));
}
private:
@@ -183,8 +184,8 @@ class ROCMWrappedFunc {
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<hipFunction_t, kMaxNumGPUs> fcache_;
- // thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
+ // launch parameters configuration
+ LaunchParamConfig launch_param_config_;
};
PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
@@ -195,7 +196,7 @@ PackedFunc ROCMModuleNode::GetFunction(const std::string& name,
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
ROCMWrappedFunc f;
- f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
+ f.Init(this, sptr_to_self, name, info.arg_types.size(), info.launch_param_tags);
return PackFuncPackedArg(f, info.arg_types);
}
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 9d140ae..ac8260f 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -19,7 +19,7 @@
/*!
* \file thread_storage_scope.h
- * \brief Extract thread axis configuration from TVMArgs.
+ * \brief Extract launch parameters configuration from TVMArgs.
*/
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
@@ -29,6 +29,8 @@
#include <string>
#include <vector>
+#include "meta_data.h"
+
namespace tvm {
namespace runtime {
@@ -182,6 +184,8 @@ struct ThreadScope {
struct ThreadWorkLoad {
// array, first three are thread configuration.
size_t work_size[6];
+ // Dynamic shared memory allocation size in bytes.
+ size_t dyn_shmem_size{0};
/*!
* \param i The block dimension.
* \return i-th block dim
@@ -193,17 +197,23 @@ struct ThreadWorkLoad {
*/
inline size_t grid_dim(size_t i) const { return work_size[i]; }
};
-/*! \brief Thread axis configuration */
-class ThreadAxisConfig {
+/*! \brief Launch parameters configuration */
+class LaunchParamConfig {
public:
- void Init(size_t base, const std::vector<std::string>& thread_axis_tags) {
+ void Init(size_t base, const std::vector<std::string>& launch_param_tags) {
base_ = base;
std::vector<bool> filled(6, false);
- for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
- const std::string& tag = thread_axis_tags[i];
- ThreadScope ts = ThreadScope::Create(tag);
- arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
- filled[ts.rank * 3 + ts.dim_index] = true;
+ for (size_t i = 0; i < launch_param_tags.size(); ++i) {
+ const std::string& tag = launch_param_tags[i];
+ if (tag == kUseDynamicSharedMemoryTag) {
+ ICHECK_EQ(i, launch_param_tags.size() - 1)
+ << "kUseDynamicSharedMemoryTag should be the last tag in
launch_param_tags.";
+ use_dyn_shared_memory_ = true;
+ } else {
+ ThreadScope ts = ThreadScope::Create(tag);
+ arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
+ filled[ts.rank * 3 + ts.dim_index] = true;
+ }
}
work_dim_ = 1;
for (int i = 0; i < 3; ++i) {
@@ -223,6 +233,9 @@ class ThreadAxisConfig {
w.work_size[arg_index_map_[i]] = size;
}
}
+ if (use_dyn_shared_memory_) {
+ w.dyn_shmem_size = static_cast<size_t>(x.values[base_ + arg_index_map_.size()].v_int64);
+ }
return w;
}
// return the work dim
@@ -235,6 +248,8 @@ class ThreadAxisConfig {
size_t work_dim_;
/*! \brief The index mapping. */
std::vector<uint32_t> arg_index_map_;
+ /*! \brief Whether or not to use dynamic shared memory. */
+ bool use_dyn_shared_memory_{false};
};
} // namespace runtime
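
A rough Python model of the Extract logic above, with illustrative values (grid dims occupy work_size[0:3], block dims work_size[3:6]):

    def extract(launch_param_tags, launch_args):
        work_size = [1] * 6
        dyn_shmem_size = 0
        for i, tag in enumerate(launch_param_tags):
            if tag == "tir.use_dyn_shared_memory":
                # Enforced above: this tag must come last, so its argument
                # is the final trailing launch parameter.
                assert i == len(launch_param_tags) - 1
                dyn_shmem_size = launch_args[i]
            else:
                axis, dim = tag.split(".")  # e.g. "blockIdx.x"
                rank = 0 if axis == "blockIdx" else 1
                work_size[rank * 3 + "xyz".index(dim)] = launch_args[i]
        return work_size, dyn_shmem_size

    # extract(["blockIdx.x", "threadIdx.x", "tir.use_dyn_shared_memory"],
    #         [64, 256, 4096]) -> ([64, 1, 1, 256, 1, 1], 4096)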
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc
index 103b2aa..0712f72 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.cc
+++ b/src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -33,13 +33,13 @@ namespace vulkan {
void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr<Object> sptr,
const std::string& func_name, size_t num_buffer_args,
size_t num_pack_args,
- const std::vector<std::string>& thread_axis_tags) {
+ const std::vector<std::string>& launch_param_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
num_buffer_args_ = num_buffer_args;
num_pack_args_ = num_pack_args;
- thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
+ launch_param_config_.Init(num_buffer_args + num_pack_args, launch_param_tags);
}
void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
@@ -50,7 +50,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
}
const auto& pipeline = scache_[device_id];
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
+ ThreadWorkLoad wl = launch_param_config_.Extract(args);
std::vector<VkDescriptorBufferInfo> descriptor_buffers;
descriptor_buffers.resize(num_buffer_args_);
for (size_t i = 0; i < num_buffer_args_; ++i) {
@@ -197,7 +197,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name,
VulkanWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
- info.thread_axis_tags);
+ info.launch_param_tags);
return PackFuncNonBufferArg(std::move(f), info.arg_types);
}
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h
index a174f22..cd4774b 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.h
+++ b/src/runtime/vulkan/vulkan_wrapped_func.h
@@ -58,7 +58,7 @@ class VulkanWrappedFunc {
public:
void Init(VulkanModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
- const std::vector<std::string>& thread_axis_tags);
+ const std::vector<std::string>& launch_param_tags);
void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const;
@@ -73,11 +73,10 @@ class VulkanWrappedFunc {
size_t num_buffer_args_;
// number of packed arguments.
size_t num_pack_args_;
+ // launch parameters configuration
+ LaunchParamConfig launch_param_config_;
// Device state cache per device.
// mark as mutable, to enable lazy initialization
- // thread axis configuration
- ThreadAxisConfig thread_axis_cfg_;
-
mutable std::array<std::shared_ptr<VulkanPipeline>, kVulkanMaxNumDevice> scache_;
};
diff --git a/src/target/build_common.h b/src/target/build_common.h
index d2fe646..c66c2b5 100644
--- a/src/target/build_common.h
+++ b/src/target/build_common.h
@@ -53,7 +53,12 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
if (auto opt = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis)) {
auto thread_axis = opt.value();
for (size_t i = 0; i < thread_axis.size(); ++i) {
- info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
+ info.launch_param_tags.push_back(thread_axis[i]->thread_tag);
+ }
+ }
+ if (auto opt = f->GetAttr<Integer>(tir::attr::kDeviceUseDynSharedMemory)) {
+ if (opt.value()) {
+ info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag);
}
}
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc
index 9aec8f4..7770e42 100644
--- a/src/target/llvm/codegen_amdgpu.cc
+++ b/src/target/llvm/codegen_amdgpu.cc
@@ -72,51 +72,45 @@ class CodeGenAMDGPU : public CodeGenLLVM {
void VisitStmt_(const AllocateNode* op) final {
ICHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
-
- int32_t constant_size = op->constant_allocation_size();
- ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
-
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
- if (constant_size % 4 == 0 && info.alignment == 0) {
- info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
- }
- // maximum necessary alignment in the AMD devices
- if (info.alignment > 16) {
- info.alignment = 16;
- }
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
- if (storage_scope.rank == runtime::StorageRank::kLocal) {
- // const int local_address_space = 5;
- // TODO(tqchen): for higher version of LLVM, local address space can be set.
- llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
- return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
- });
- if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
-#if TVM_LLVM_VERSION >= 100
- alloca->setAlignment(llvm::Align(info.alignment));
-#else
- alloca->setAlignment(info.alignment);
-#endif
- }
- buf = alloca;
+
+ if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
+ LOG(WARNING) << "Dynamic shared memory support for rocm is experimental.";
+ buf = AllocateSharedMemory(op->dtype, 0, 3, std::min(info.alignment, 16),
+ llvm::GlobalValue::ExternalLinkage);
} else {
- ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
- << "Can only allocate shared or local memory inside kernel";
- // Shared memory: address space == 3
- const unsigned shared_address_space = 3;
- llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size);
- // Allocate shared memory in global, address_space = 3
- llvm::GlobalVariable* global = new llvm::GlobalVariable(
- *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr,
- llvm::GlobalValue::NotThreadLocal, shared_address_space);
- if (global->getAlignment() < static_cast<uint32_t>(info.alignment)) {
+ int32_t constant_size = op->constant_allocation_size();
+ ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
+
+ if (constant_size % 4 == 0 && info.alignment == 0) {
+ info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
+ }
+ // maximum necessary alignment in the AMD devices
+ if (info.alignment > 16) {
+ info.alignment = 16;
+ }
+ if (storage_scope.rank == runtime::StorageRank::kLocal) {
+ // const int local_address_space = 5;
+ // TODO(tqchen): for higher version of LLVM, local address space can be set.
+ llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
+ return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
+ });
+ if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
- global->setAlignment(llvm::Align(info.alignment));
+ alloca->setAlignment(llvm::Align(info.alignment));
#else
- global->setAlignment(info.alignment);
+ alloca->setAlignment(info.alignment);
#endif
+ }
+ buf = alloca;
+ } else {
+ ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
+ << "Can only allocate shared or local memory inside kernel";
+ // Shared memory: address space == 3
+ buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
+ llvm::GlobalValue::PrivateLinkage);
}
- buf = global;
}
buf = builder_->CreatePointerCast(
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bdae93b..b83748b 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -524,6 +524,22 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp
*p_alignment = align_bits / 8;
}
+llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t size,
+ unsigned int shared_address_space,
+ int alignment,
+ llvm::GlobalValue::LinkageTypes linkage) {
+ llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
+ llvm::GlobalVariable* global =
+ new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr,
+ llvm::GlobalValue::NotThreadLocal, shared_address_space);
+#if TVM_LLVM_VERSION >= 100
+ global->setAlignment(llvm::Align(alignment));
+#else
+ global->setAlignment(alignment);
+#endif
+ return global;
+}
+
std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
#if TVM_LLVM_VERSION >= 100
auto debug_info = std::make_unique<CodeGenLLVM::DebugInfo>();
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 810e59b..52c5b98 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -292,6 +292,11 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
const Var& loop_var, const Stmt& body);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index);
+
+ llvm::GlobalVariable* AllocateSharedMemory(DataType dtype, size_t size,
+ unsigned int shared_address_space, int alignment,
+ llvm::GlobalValue::LinkageTypes linkage);
+
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index 43ea0e6..15543ed 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -48,48 +48,44 @@ class CodeGenNVPTX : public CodeGenLLVM {
void VisitStmt_(const AllocateNode* op) final {
ICHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
-
- int32_t constant_size = op->constant_allocation_size();
- ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
- if (constant_size % 4 == 0 && info.alignment == 0) {
- info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
- }
// maximum necessary alignment in the NV devices
if (info.alignment > 16) {
info.alignment = 16;
}
+
auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
- if (storage_scope.rank == runtime::StorageRank::kLocal) {
- // const int local_address_space = 5;
- // TODO(tqchen): for higher version of LLVM, local address space can be set.
- llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
- return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
- });
- if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
-#if TVM_LLVM_VERSION >= 100
- alloca->setAlignment(llvm::Align(info.alignment));
-#else
- alloca->setAlignment(info.alignment);
-#endif
- }
- buf = alloca;
- } else {
- ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
- << "Can only allocate shared or local memory inside kernel";
+ if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
// Shared memory: address space == 3
- const unsigned shared_address_space = 3;
- llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size);
- // Allocate shared memory in global, address_space = 3
- llvm::GlobalVariable* global = new llvm::GlobalVariable(
- *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr,
- llvm::GlobalValue::NotThreadLocal, shared_address_space);
+ buf =
+ AllocateSharedMemory(op->dtype, 0, 3, info.alignment, llvm::GlobalValue::ExternalLinkage);
+ } else {
+ int32_t constant_size = op->constant_allocation_size();
+ ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
+
+ if (constant_size % 4 == 0 && info.alignment == 0) {
+ info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
+ }
+ if (storage_scope.rank == runtime::StorageRank::kLocal) {
+ // const int local_address_space = 5;
+ // TODO(tqchen): for higher version of LLVM, local address space can be set.
+ llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
+ return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
+ });
+ if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
- global->setAlignment(llvm::Align(info.alignment));
+ alloca->setAlignment(llvm::Align(info.alignment));
#else
- global->setAlignment(info.alignment);
+ alloca->setAlignment(info.alignment);
#endif
- buf = global;
+ }
+ buf = alloca;
+ } else {
+ ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
+ << "Can only allocate shared or local memory inside kernel";
+ buf = AllocateSharedMemory(op->dtype, constant_size, 3, info.alignment,
+ llvm::GlobalValue::PrivateLinkage);
+ }
}
buf = builder_->CreatePointerCast(
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d7dcbec..7897490 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -525,6 +525,8 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os)
"all global arrays as input instead";
if (scope == "shared") {
os << "__shared__ ";
+ } else if (scope == "shared.dyn") {
+ os << "extern __shared__ ";
}
}
@@ -703,9 +705,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
- int32_t constant_size = op->constant_allocation_size();
- ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
std::string scope = GetPtrStorageScope(op->buffer_var);
+ const VarNode* buffer = op->buffer_var.as<VarNode>();
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
@@ -719,19 +720,28 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
- const VarNode* buffer = op->buffer_var.as<VarNode>();
- constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
}
- if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
- op->dtype == DataType::Int(1)) &&
- scope == "shared") {
- constant_size = constant_size / (32 / op->dtype.bits());
+
+ if (scope == "shared.dyn") {
+ stream << ' ' << vid << "[];\n";
+ } else {
+ int32_t constant_size = op->constant_allocation_size();
+ ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
+
+ if (scope.find("wmma.") == 0) {
+ constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
+ }
+ if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
+ op->dtype == DataType::Int(1)) &&
+ scope == "shared") {
+ constant_size = constant_size / (32 / op->dtype.bits());
+ }
+ stream << ' ' << vid << '[' << constant_size << "];\n";
}
- stream << ' ' << vid << '[' << constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
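
One way to observe the new path is to print the generated CUDA source for a kernel that allocates in "shared.dyn" scope; a sketch, assuming a schedule s with arguments A and B as in the test below:

    # The dynamic allocation should appear as an unsized extern declaration,
    # e.g. "extern __shared__ float temp[];", while static "shared" buffers
    # keep their compile-time sizes.
    mod = tvm.build(s, [A, B], "cuda")
    print(mod.imported_modules[0].get_source())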
diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc
index 829b7d8..eafed83 100644
--- a/src/tir/transforms/lower_device_storage_access_info.cc
+++ b/src/tir/transforms/lower_device_storage_access_info.cc
@@ -67,7 +67,7 @@ class StorageAccessInfoLower : public StmtExprMutator {
StorageScope scope = StorageScope::Create(op->value.as<StringImmNode>()->value);
StorageEntry e;
e.scope = scope;
- if (scope.tag.length() != 0) {
+ if (scope.tag.length() != 0 && scope.tag != ".dyn") {
e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index f01d987..795ae9d 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -33,6 +33,9 @@
#include <unordered_map>
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
namespace tvm {
namespace tir {
@@ -89,6 +92,17 @@ class VarUseDefAnalysis : public StmtExprMutator {
Stmt VisitStmt_(const AllocateNode* op) final {
this->HandleDef(op->buffer_var.get());
+ auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
+ if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") {
+ ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed.";
+ ICHECK_GT(op->extents.size(), 0);
+ dyn_shmem_size_ = op->extents[0];
+ for (size_t i = 1; i < op->extents.size(); ++i) {
+ dyn_shmem_size_ *= op->extents[i];
+ }
+ dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes());
+ use_dyn_shmem_ = true;
+ }
return StmtExprMutator::VisitStmt_(op);
}
@@ -175,6 +189,8 @@ class VarUseDefAnalysis : public StmtExprMutator {
Array<Var> undefined_;
Array<IterVar> thread_axis_;
Array<PrimExpr> thread_extent_;
+ PrimExpr dyn_shmem_size_{0};
+ bool use_dyn_shmem_{false};
std::unordered_map<const VarNode*, int> use_count_;
std::unordered_map<const VarNode*, int> def_count_;
@@ -262,6 +278,10 @@ class HostDeviceSplitter : public StmtMutator {
WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
+ if (m.use_dyn_shmem_) {
+ device_func =
+ WithAttr(std::move(device_func), tir::attr::kDeviceUseDynSharedMemory, Integer(1));
+ }
(*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
// generate calls to the device function
@@ -273,6 +293,9 @@ class HostDeviceSplitter : public StmtMutator {
for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
+ if (m.use_dyn_shmem_) {
+ call_args.push_back(m.dyn_shmem_size_);
+ }
return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args));
}
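
The byte size forwarded to the runtime is simply the product of the allocation's extents times the element size; an equivalent sketch:

    def dyn_shmem_bytes(extents, dtype_bits):
        # Mirrors the fold in VisitStmt_ above:
        # extent_0 * extent_1 * ... * dtype.bytes()
        size = 1
        for extent in extents:
            size *= extent
        return size * (dtype_bits // 8)

    assert dyn_shmem_bytes([1024], 32) == 4096  # 1024 float32 elements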
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 613d026..b216b8b 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -512,7 +512,7 @@ class StoragePlanRewriter : public StmtExprMutator {
// try to find merge, for tagged memory
for (size_t i = 0; i < vec.size(); ++i) {
StorageEntry* e = vec[i];
- if (e->scope.tag.length() != 0) {
+ if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") {
ICHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size";
for (size_t j = 0; j < i; ++j) {
if (e->scope == vec[j]->scope) {
@@ -546,7 +546,7 @@ class StoragePlanRewriter : public StmtExprMutator {
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
e->new_alloc =
Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
- if (e->scope.tag.length() != 0) {
+ if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
@@ -587,7 +587,7 @@ class StoragePlanRewriter : public StmtExprMutator {
combo_size = analyzer_.Simplify(combo_size);
e->new_alloc =
Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0));
- if (e->scope.tag.length() != 0) {
+ if (e->scope.tag.length() != 0 && e->scope.tag != ".dyn") {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py
index 355d3ab..0329134 100644
--- a/tests/python/unittest/test_tir_ir_builder.py
+++ b/tests/python/unittest/test_tir_ir_builder.py
@@ -18,6 +18,7 @@ import tvm
from tvm import te
import numpy as np
import tvm.testing
+from tvm.topi.math import cast
def test_for():
@@ -497,6 +498,62 @@ def test_while_binary_search():
check_target("vulkan", searchsorted_ir_gpu)
+@tvm.testing.requires_gpu
+def test_dyn_shared():
+ n = te.size_var("n")
+ dtype = "float32"
+ A = te.placeholder((n,), name="A")
+
+ def test_device_ir(A, B):
+ n = A.shape[0]
+ ib = tvm.tir.ir_builder.create()
+
+ tx = te.thread_axis("threadIdx.x")
+ ib.scope_attr(tx, "thread_extent", n)
+
+ temp = ib.allocate(dtype, (n,), scope="shared.dyn") # n is symbolic size
+
+ Aptr = ib.buffer_ptr(A)
+ Bptr = ib.buffer_ptr(B)
+
+ temp[tx] = Aptr[tx]
+ depth = tvm.tir.log2(cast(n, "float32"))
+
+ with ib.for_range(0, depth) as i:
+ ib.emit(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])))
+ d = n >> (i + 1)
+ with ib.if_scope(tx < d):
+ temp[tx] += temp[tx + d]
+
+ Bptr[0] = temp[0]
+ return ib.get()
+
+ B = te.extern(
+ (1,),
+ [A],
+ lambda ins, outs: test_device_ir(ins[0], outs[0]),
+ name="reduce",
+ dtype=dtype,
+ )
+ s = te.create_schedule(B.op)
+
+ def check_target(target):
+ if not tvm.testing.device_enabled(target):
+ return
+
+ freduce = tvm.build(s, [A, B], target)
+ dev = tvm.device(target, 0)
+
+ for n in [512, 1024]:
+ a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+ b = tvm.nd.array(np.zeros(1, dtype=B.dtype), dev)
+ freduce(a, b)
+ tvm.testing.assert_allclose(b.numpy()[0], np.sum(a.numpy()), 1e-4, 1e-4)
+
+ for target in ["cuda", "nvptx"]:
+ check_target(target)
+
+
if __name__ == "__main__":
test_prefetch()
test_if()
@@ -507,3 +564,4 @@ if __name__ == "__main__":
test_while_collatz()
test_while_mandel()
test_while_binary_search()
+ test_dyn_shared()
diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts
index f12837f..226797e 100644
--- a/web/src/webgpu.ts
+++ b/web/src/webgpu.ts
@@ -39,7 +39,7 @@ export async function detectGPUDevice(): Promise<GPUDevice | undefined | null> {
interface FunctionInfo {
name: string;
arg_types: Array<string>;
- thread_axis_tags: Array<string>;
+ launch_param_tags: Array<string>;
}
/**
@@ -114,8 +114,8 @@ export class WebGPUContext {
const dispatchToDim: Array<number> = [];
- for (let i = 0; i < finfo.thread_axis_tags.length; ++i) {
- const tag: string = finfo.thread_axis_tags[i];
+ for (let i = 0; i < finfo.launch_param_tags.length; ++i) {
+ const tag: string = finfo.launch_param_tags[i];
if (tag.startsWith("blockIdx.")) {
const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0));
assert(target >= 0 && target < 3);