This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ec3171ab7a [REFACTOR][TIR] Tie
AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together (#19605)
ec3171ab7a is described below
commit ec3171ab7a4c06fff4e9c1e441d28ef4e9a5831b
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 22:10:54 2026 -0400
[REFACTOR][TIR] Tie
AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together (#19605)
## Summary
`AnnotateDeviceRegions`, `SplitHostDevice`, and `LowerDeviceKernelLaunch`
are logically a single host/device split step; having other passes
scheduled between them obscures that model and blocks folding them into
one pass. This PR moves each intervening pass to the position its actual
ordering constraint allows, so that the three passes run consecutively
in every pipeline.
## Rationale
- `MergeSharedMemoryAllocations` moves **before**
`AnnotateDeviceRegions`
(the only legal position that keeps the three passes adjacent:
`LowerDeviceKernelLaunch` requires at most one dynamic shared-memory
allocation per kernel, so the merge cannot be moved after
`LowerDeviceKernelLaunch`).
- `MakePackedAPI` moves **after** `LowerDeviceKernelLaunch` (the
`kCallingConv = kDeviceKernelLaunch` flag set by
`LowerDeviceKernelLaunch` causes `MakePackedAPI` to correctly skip
device kernels; the host body's lowered `tvm_call_packed` is
transparent to `MakePackedAPI`'s subroutine rewriter).
- `FP8StorageLegalize` / `BF16StorageLegalize` move **after**
`MakePackedAPI` (their `buffer_map.size()==0` ICHECK requires
`MakePackedAPI` to have cleared the map).
This PR is a prerequisite for Phase 2: collapsing the three consecutive
passes into a single `tirx.transform.SplitHostDevice` with three
commented regions.
## Test plan
- [x] tests/python/tirx-transform/ target-pass unit tests (25 pass)
- [x] tests/python/s_tir/transform/test_merge_dynamic_shared_memory_allocations.py (5 pass)
- [x] tests/python/tirx-transform/test_tir_transform_fp8_legalize.py /
test_tir_transform_bf16_legalize.py (13 pass)
- [x] tests/python/codegen/test_target_codegen_c_host.py /
test_target_codegen_device.py (6 pass including
test_subroutine_call — verifies Risk #2)
- [x] pre-commit run --all-files clean
- [ ] CI: lint / Windows / MacOS
---
python/tvm/s_tir/backend/adreno/pipeline.py | 5 +-
python/tvm/s_tir/pipeline.py | 5 +-
python/tvm/tirx/compilation_pipeline.py | 6 +-
.../transform/merge_shared_memory_allocations.cc | 447 +++++++++++++--------
src/tirx/transform/lower_device_kernel_launch.cc | 91 ++++-
...form_merge_dynamic_shared_memory_allocations.py | 95 ++++-
6 files changed, 434 insertions(+), 215 deletions(-)
diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py
b/python/tvm/s_tir/backend/adreno/pipeline.py
index 85359b1d35..618970b37e 100644
--- a/python/tvm/s_tir/backend/adreno/pipeline.py
+++ b/python/tvm/s_tir/backend/adreno/pipeline.py
@@ -108,14 +108,13 @@ def default_tir_pipeline():
passes.append(s_tir.transform.InjectPTXLDG32())
passes.extend(
[
+ s_tir.transform.MergeSharedMemoryAllocations(),
tirx.transform.AnnotateDeviceRegions(),
tirx.transform.SplitHostDevice(),
- # MergeSharedMemoryAllocations must follow SplitHostDevice.
- s_tir.transform.MergeSharedMemoryAllocations(),
+ tirx.transform.LowerDeviceKernelLaunch(),
tirx.transform.MakePackedAPI(),
tirx.transform.FP8StorageLegalize(),
tirx.transform.BF16StorageLegalize(),
- tirx.transform.LowerDeviceKernelLaunch(),
]
)
mod = tvm.ir.transform.Sequential(passes)(mod)
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index 33a16b381f..a127e43a0e 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -108,14 +108,13 @@ def default_s_tir_pipeline():
passes.append(s_tir.transform.InjectPTXLDG32())
passes.extend(
[
+ s_tir.transform.MergeSharedMemoryAllocations(),
tirx.transform.AnnotateDeviceRegions(),
tirx.transform.SplitHostDevice(),
- # MergeSharedMemoryAllocations must follow SplitHostDevice.
- s_tir.transform.MergeSharedMemoryAllocations(),
+ tirx.transform.LowerDeviceKernelLaunch(),
tirx.transform.MakePackedAPI(),
tirx.transform.FP8StorageLegalize(),
tirx.transform.BF16StorageLegalize(),
- tirx.transform.LowerDeviceKernelLaunch(),
]
)
mod = tvm.ir.transform.Sequential(passes)(mod)
diff --git a/python/tvm/tirx/compilation_pipeline.py
b/python/tvm/tirx/compilation_pipeline.py
index 30facc2663..f964f50668 100644
--- a/python/tvm/tirx/compilation_pipeline.py
+++ b/python/tvm/tirx/compilation_pipeline.py
@@ -50,10 +50,10 @@ def default_tir_pipeline():
tirx.transform.AnnotateEntryFunc(),
tirx.transform.AnnotateDeviceRegions(),
tirx.transform.SplitHostDevice(),
+ tirx.transform.LowerDeviceKernelLaunch(),
tirx.transform.MakePackedAPI(),
tirx.transform.FP8StorageLegalize(),
tirx.transform.BF16StorageLegalize(),
- tirx.transform.LowerDeviceKernelLaunch(),
]
)
mod = tvm.ir.transform.Sequential(passes)(mod)
@@ -91,10 +91,10 @@ def tirx_pipeline():
tirx.transform.AnnotateEntryFunc(),
tirx.transform.AnnotateDeviceRegions(),
tirx.transform.SplitHostDevice(),
+ tirx.transform.LowerDeviceKernelLaunch(),
tirx.transform.MakePackedAPI(),
tirx.transform.FP8StorageLegalize(),
tirx.transform.BF16StorageLegalize(),
- tirx.transform.LowerDeviceKernelLaunch(),
]
)
mod = tvm.ir.transform.Sequential(passes)(mod)
@@ -124,8 +124,8 @@ def trn_pipeline():
tirx.transform.AnnotateEntryFunc(),
tirx.transform.AnnotateDeviceRegions(),
tirx.transform.SplitHostDevice(),
- tirx.transform.MakePackedAPI(),
tirx.transform.LowerDeviceKernelLaunch(),
+ tirx.transform.MakePackedAPI(),
]
return tvm.ir.transform.Sequential(passes)(mod)
diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc
b/src/s_tir/transform/merge_shared_memory_allocations.cc
index c680eb38ac..d1417943c3 100644
--- a/src/s_tir/transform/merge_shared_memory_allocations.cc
+++ b/src/s_tir/transform/merge_shared_memory_allocations.cc
@@ -77,24 +77,26 @@ static int64_t ConstantAllocationSize(const
ffi::Array<PrimExpr>& extents) {
}
/*!
- * \brief collect the mapping from the buffer var to its Buffer
+ * \brief collect the mapping from the buffer var to its Buffer within a
subtree
*/
class AllocateCollector : public StmtExprVisitor {
public:
+ explicit AllocateCollector(bool is_dynamic) : is_dynamic_(is_dynamic) {}
+
void VisitStmt_(const AllocBufferNode* op) final {
- if (IsDynamicSharedMemory(op->buffer->data) ||
IsStaticSharedMemory(op->buffer->data)) {
- if (IsDynamicSharedMemory(op->buffer->data)) {
- dyn_shmem_allocs_[op->buffer->data.get()] = op->buffer;
- } else {
- static_shmem_allocs_[op->buffer->data.get()] = op->buffer;
- }
+ if (is_dynamic_ && IsDynamicSharedMemory(op->buffer->data)) {
+ shmem_allocs_[op->buffer->data.get()] = op->buffer;
+ } else if (!is_dynamic_ && IsStaticSharedMemory(op->buffer->data)) {
+ shmem_allocs_[op->buffer->data.get()] = op->buffer;
}
StmtExprVisitor::VisitStmt_(op);
}
- // The dynamic mapping from the original buffer var to its Buffer
- std::unordered_map<const VarNode*, Buffer> dyn_shmem_allocs_;
- // The static mapping from the original buffer var to its Buffer
- std::unordered_map<const VarNode*, Buffer> static_shmem_allocs_;
+
+ // The mapping from the original buffer var to its Buffer
+ std::unordered_map<const VarNode*, Buffer> shmem_allocs_;
+
+ private:
+ bool is_dynamic_;
};
// Find a linear pattern of storage access
@@ -274,89 +276,131 @@ class SharedMemLinearAccessPatternFinder final : public
StmtExprVisitor {
/*!
* \brief merge the buffers whose live range has no intersection and rewrite
the body
+ *
+ * Uses a scope-stack design: each thread_extent block (kernel launch) gets its
+ * own KernelScope that owns the merged buffer var and all per-launch
bookkeeping.
+ * This correctly handles PrimFuncs with multiple sibling thread_extent blocks.
*/
class SharedMemoryRewriter : public StmtExprMutator {
public:
- explicit SharedMemoryRewriter(const std::unordered_map<const VarNode*,
Buffer>& shmem_allocs,
- bool is_dynamic = true)
- : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} {
- if (!is_dynamic) {
- merged_buf_var_ = Var("buf_shmem",
PointerType(PrimType(DataType::UInt(8)), "shared"));
- }
- }
+ explicit SharedMemoryRewriter(bool is_dynamic = true) :
is_dynamic_{is_dynamic} {}
+
+ private:
+ using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
+
+ struct StorageEntry {
+ // The constant size of the buffer in bits, only used if it is constant
+ uint64_t const_nbits{0};
+ // Allocs that shares this entry.
+ // The inner vector means a "layer"
+ // For example, it we need to allocate C in the memory of A and B:
+ // | A: 4096 bytes | B: 4096 bytes |
+ // | C: 8192 bytes |
+ // Then the allocs = {{A, B}, {C}}
+ std::vector<std::vector<const VarNode*>> allocs;
+ };
+
+ // Event entry in liveness analysis
+ struct EventEntry {
+ // variables we generate
+ std::vector<const VarNode*> gen;
+ // variables we kill
+ std::vector<const VarNode*> kill;
+ };
/*!
- * \brief plan the memory reuse for all the buffer allocated in the statement
- * \param stmt the statement
+ * \brief Per-kernel-launch scope holding all state for one thread_extent
block.
*/
- void PlanReuse(const Stmt& stmt, bool is_dynamic = true) {
- SharedMemLinearAccessPatternFinder finder(is_dynamic);
- finder(stmt);
- this->LivenessAnalysis(finder.linear_seq_);
- this->PlanMemory(finder.linear_seq_);
+ struct KernelScope {
+ // The merged buffer var for THIS kernel launch.
+ Var merged_buf_var;
+ // Total byte size of THIS kernel's merged buffer.
+ PrimExpr merged_alloc_size{0};
+ // Allocations from THIS kernel's subtree.
+ std::unordered_map<const VarNode*, Buffer> shmem_allocs;
+ // Per-buffer byte offset into merged_buf_var.
+ std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets;
+ // Buffer-object remap: original Buffer -> merged-data-var Buffer.
+ std::unordered_map<const BufferNode*, Buffer> buffer_remap;
+ // Has any original alloc in this scope been marked volatile?
+ bool has_volatile_alloc{false};
+ // Liveness data (event_map, alloc_map, const_free_map, sym_free_list) —
all per-scope.
+ std::unordered_map<const ffi::Object*, EventEntry> event_map;
+ std::multimap<uint64_t, StorageEntry*> const_free_map;
+ std::list<StorageEntry*> sym_free_list;
+ std::unordered_map<const VarNode*, StorageEntry*> alloc_map;
+ };
+
+ /*!
+ * \brief Create a fresh merged buffer Var for a new kernel scope.
+ * Same name string is fine — Var identity is by pointer, not name.
+ */
+ Var MakeMergedBufferVar() {
+ if (is_dynamic_) {
+ return Var("buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)),
"shared.dyn"));
+ } else {
+ return Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)),
"shared"));
+ }
}
- private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tirx::attr::thread_extent && !allocated_) {
- // Allocate one dynamic shared memory allocation at the beginning of
thread scope
- int max_layer_num = 0;
- std::vector<const StorageEntry*> all_entry;
- for (const auto& e : const_free_map_) {
- all_entry.push_back(e.second);
- }
- for (const StorageEntry* e : sym_free_list_) {
- all_entry.push_back(e);
- }
- for (const StorageEntry* e : all_entry) {
- max_layer_num = std::max(max_layer_num,
static_cast<int>(e->allocs.size()));
- }
- // calculate align for each layer of each storage entry.
- std::vector<int> align(max_layer_num, 0);
- for (const StorageEntry* e : all_entry) {
- for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
- for (const VarNode* buffer : e->allocs[i]) {
- const Buffer& buf = shmem_allocs_.at(buffer);
- align[i] = std::max(align[i], buf->dtype.bytes());
- }
- }
- }
- // calculate offset for each buffer based on the align of each layer
- for (const StorageEntry* e : all_entry) {
- PrimExpr max_inner_offset = 0;
- for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
- PrimExpr inner_offset = 0;
- for (const VarNode* buffer : e->allocs[i]) {
- const Buffer& buf = shmem_allocs_.at(buffer);
- ffi::Array<PrimExpr> alloc_shape = GetBufferAllocationShape(buf);
- int align_bytes = std::max(align[i], buf->dtype.bytes());
- if (buf->data_alignment > 0) {
- TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0)
- << "The alignment of the buffer is not a multiple of the
data type size.";
- align_bytes = buf->data_alignment;
- }
- PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes();
- inner_offset +=
- indexmod(align_bytes - indexmod(merged_alloc_size_ +
inner_offset, align_bytes),
- align_bytes);
- buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
- inner_offset += buffer_bytes;
- }
- max_inner_offset = max(max_inner_offset, inner_offset);
- }
- merged_alloc_size_ += max_inner_offset;
+ if (op->attr_key == tirx::attr::thread_extent && !in_thread_env_) {
+ in_thread_env_ = true;
+
+ // 1. Push a fresh scope.
+ scope_stack_.emplace_back();
+ KernelScope& scope = scope_stack_.back();
+ scope.merged_buf_var = MakeMergedBufferVar();
+
+ // 2. Collect shmem allocs that belong to THIS subtree.
+ AllocateCollector collector(is_dynamic_);
+ collector(op->body);
+ scope.shmem_allocs = std::move(collector.shmem_allocs_);
+
+ // Per-scope early bail-out: if this thread_extent block has ≤1 shmem
+ // allocation, there is nothing to merge. Skip liveness analysis,
+ // memory planning, and rewriting entirely.
+ if (scope.shmem_allocs.size() <= 1) {
+ scope_stack_.pop_back();
+ in_thread_env_ = false;
+ return StmtExprMutator::VisitStmt_(op);
}
- allocated_ = true;
- Buffer merged_buf(merged_buf_var_, DataType::UInt(8),
{merged_alloc_size_}, {}, PrimExpr(),
- merged_buf_var_->name_hint, 0, 0,
BufferType::kDefault);
+ // 3. Liveness + reuse plan over this subtree only.
+ // Run the finder on the full AttrStmt (not just op->body) so that
+ // VisitNewScope creates the proper scope pair entry for the
thread_extent.
+ SharedMemLinearAccessPatternFinder finder(is_dynamic_);
+ finder(ffi::GetRef<Stmt>(op));
+ this->LivenessAnalysis(finder.linear_seq_, scope);
+ this->PlanMemory(finder.linear_seq_, scope);
+
+ // 4. Compute byte offsets / merged_alloc_size.
+ this->ComputeOffsets(scope);
+
+ // 5. Recursively mutate the body — reads scope_stack_.back() for all
rewrites.
Stmt visited_body = StmtExprMutator::VisitStmt(op->body);
+
+ in_thread_env_ = false;
+
+ // 6. If this scope has no shmem allocs, skip the wrapper.
+ if (scope.shmem_allocs.empty()) {
+ scope_stack_.pop_back();
+ return AttrStmt(op->node, op->attr_key, op->value, visited_body,
op->span);
+ }
+
+ // 7. Wrap with the merged-buffer AllocBuffer.
+ Buffer merged_buf(scope.merged_buf_var, DataType::UInt(8),
{scope.merged_alloc_size}, {},
+ PrimExpr(), scope.merged_buf_var->name_hint, 0, 0,
BufferType::kDefault);
ffi::Map<ffi::String, ffi::Any> annotations;
- if (has_volatile_alloc_) {
+ if (scope.has_volatile_alloc) {
annotations.Set(tirx::attr::kVolatile, true);
}
Stmt alloc_stmt = AllocBuffer(merged_buf, annotations);
Stmt new_body = SeqStmt::Flatten(alloc_stmt, visited_body);
+
+ // 8. Pop the scope.
+ scope_stack_.pop_back();
+
return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
}
return StmtMutator::VisitStmt_(op);
@@ -364,10 +408,17 @@ class SharedMemoryRewriter : public StmtExprMutator {
Stmt VisitStmt_(const AllocBufferNode* op) final {
if (IsAppropriateSharedMemory(op->buffer->data)) {
- if (op->annotations.count(tirx::attr::kVolatile)) {
- has_volatile_alloc_ = true;
+ if (!scope_stack_.empty()) {
+ KernelScope& scope = scope_stack_.back();
+ if (scope.shmem_allocs.count(op->buffer->data.get())) {
+ if (op->annotations.count(tirx::attr::kVolatile)) {
+ scope.has_volatile_alloc = true;
+ }
+ return Evaluate(0);
+ }
}
- return Evaluate(0);
+ // Outside any thread_extent scope — leave as-is.
+ return StmtExprMutator::VisitStmt_(op);
}
return StmtExprMutator::VisitStmt_(op);
}
@@ -392,7 +443,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
template <typename Node>
Node VisitBufferAccess(Node node) {
- if (IsAppropriateSharedMemory(node->buffer->data)) {
+ if (IsAppropriateSharedMemory(node->buffer->data) && !scope_stack_.empty()
&&
+ scope_stack_.back().shmem_allocs.count(node->buffer->data.get())) {
TVM_FFI_ICHECK_EQ(node->indices.size(), 1)
<< "MergeSharedMemoryAllocations expects flat memory buffers, "
<< "and is to be run after "
@@ -409,9 +461,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
}
Buffer GetUpdatedBuffer(Buffer buffer) {
+ if (scope_stack_.empty()) return buffer;
+ KernelScope& scope = scope_stack_.back();
+ if (!scope.shmem_allocs.count(buffer->data.get())) return buffer;
+
auto key = buffer.get();
- auto it = buffer_remap_.find(key);
- if (it != buffer_remap_.end()) {
+ auto it = scope.buffer_remap.find(key);
+ if (it != scope.buffer_remap.end()) {
return it->second;
}
@@ -422,10 +478,10 @@ class SharedMemoryRewriter : public StmtExprMutator {
<< "and is to be run after "
<< "FlattenBuffer";
auto writer = buffer.CopyOnWrite();
- writer->data = merged_buf_var_;
+ writer->data = scope.merged_buf_var;
}
- buffer_remap_[key] = buffer;
+ scope.buffer_remap[key] = buffer;
return buffer;
}
@@ -434,7 +490,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
TVM_FFI_ICHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
Var buffer = Downcast<Var>(op->args[1]);
- if (!IsAppropriateSharedMemory(buffer)) {
+ if (!IsAppropriateSharedMemory(buffer) || scope_stack_.empty() ||
+ !scope_stack_.back().shmem_allocs.count(buffer.get())) {
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -442,7 +499,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
PrimExpr offset = this->VisitExpr(op->args[2]);
PrimExpr extent = this->VisitExpr(op->args[3]);
return Call(op->dtype, op->op,
- {op->args[0], merged_buf_var_, extra_offset + offset,
extent, op->args[4]});
+ {op->args[0], scope_stack_.back().merged_buf_var,
extra_offset + offset, extent,
+ op->args[4]});
} else if (op->op.same_as(builtin::ptx_cp_async())) {
TVM_FFI_ICHECK((op->args.size() == 5U) || (op->args.size() == 6U));
Var buffer = Downcast<Var>(op->args[0]);
@@ -451,7 +509,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
TVM_FFI_ICHECK(prim_type) << "The buffer should be a pointer to a
primitive type.";
DataType dtype = DataType(prim_type->dtype);
- if (!IsAppropriateSharedMemory(buffer)) {
+ if (!IsAppropriateSharedMemory(buffer) || scope_stack_.empty() ||
+ !scope_stack_.back().shmem_allocs.count(buffer.get())) {
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -461,21 +520,25 @@ class SharedMemoryRewriter : public StmtExprMutator {
// the correct offset of merged shared buffer.
int index_factor = dtype.bytes();
if (op->args.size() == 5)
- return Call(dtype, op->op,
- {merged_buf_var_, mul(extra_offset + offset,
PrimExpr(index_factor)),
- op->args[2], op->args[3], op->args[4]});
+ return Call(
+ dtype, op->op,
+ {scope_stack_.back().merged_buf_var, mul(extra_offset + offset,
PrimExpr(index_factor)),
+ op->args[2], op->args[3], op->args[4]});
else
- return Call(dtype, op->op,
- {merged_buf_var_, mul(extra_offset + offset,
PrimExpr(index_factor)),
- op->args[2], op->args[3], op->args[4], op->args[5]});
+ return Call(
+ dtype, op->op,
+ {scope_stack_.back().merged_buf_var, mul(extra_offset + offset,
PrimExpr(index_factor)),
+ op->args[2], op->args[3], op->args[4], op->args[5]});
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
- auto it = buffer_byte_offsets_.find(buffer_var.get());
- TVM_FFI_ICHECK(it != buffer_byte_offsets_.end());
+ TVM_FFI_ICHECK(!scope_stack_.empty());
+ KernelScope& scope = scope_stack_.back();
+ auto it = scope.buffer_byte_offsets.find(buffer_var.get());
+ TVM_FFI_ICHECK(it != scope.buffer_byte_offsets.end());
return indexdiv(it->second, dtype.bytes());
}
@@ -484,32 +547,12 @@ class SharedMemoryRewriter : public StmtExprMutator {
return is_dynamic_ ? IsDynamicSharedMemory(var) :
IsStaticSharedMemory(var);
}
- using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
- struct StorageEntry {
- // The constant size of the buffer in bits, only used if it is constant
- uint64_t const_nbits{0};
- // Allocs that shares this entry.
- // The inner vector means a "layer"
- // For example, it we need to allocate C in the memory of A and B:
- // | A: 4096 bytes | B: 4096 bytes |
- // | C: 8192 bytes |
- // Then the allocs = {{A, B}, {C}}
- std::vector<std::vector<const VarNode*>> allocs;
- };
-
- // Event entry in liveness analysis
- struct EventEntry {
- // variables we generate
- std::vector<const VarNode*> gen;
- // variables we kill
- std::vector<const VarNode*> kill;
- };
-
/*!
* \brief Liveness analysis to find gen and kill point of each variable.
* \param seq the linear pattern of storage access
+ * \param scope the kernel scope to write results into
*/
- void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
+ void LivenessAnalysis(const std::vector<StmtEntry>& seq, KernelScope& scope)
{
// find kill point, do a reverse linear scan.
std::unordered_set<const VarNode*> touched;
for (size_t i = seq.size(); i != 0; --i) {
@@ -517,7 +560,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
for (const VarNode* buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
- event_map_[s.stmt].kill.push_back(buffer);
+ scope.event_map[s.stmt].kill.push_back(buffer);
}
}
}
@@ -530,7 +573,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
for (const VarNode* buffer : s.touched) {
if (!touched.count(buffer)) {
touched.insert(buffer);
- event_map_[s.stmt].gen.push_back(buffer);
+ scope.event_map[s.stmt].gen.push_back(buffer);
}
}
}
@@ -539,12 +582,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
/*!
* \brief Memory plan algorithm
* \param seq the linear pattern of storage access
+ * \param scope the kernel scope to write results into
*/
- void PlanMemory(const std::vector<StmtEntry>& seq) {
+ void PlanMemory(const std::vector<StmtEntry>& seq, KernelScope& scope) {
std::unordered_set<const VarNode*> inplace_flag;
for (size_t i = 0; i < seq.size(); ++i) {
- auto it = event_map_.find(seq[i].stmt);
+ auto it = scope.event_map.find(seq[i].stmt);
// scope_pair_offset <= 0 means it is either
// - leaf stmt(offset = 0)
// - end of scope(offset < 0)
@@ -553,30 +597,84 @@ class SharedMemoryRewriter : public StmtExprMutator {
return seq[i].scope_pair_offset == 0 &&
std::find(it->second.gen.begin(), it->second.gen.end(), var) !=
it->second.gen.end();
};
- if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+ if (it != scope.event_map.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode* var : it->second.kill) {
- if (!is_leaf_alloc(var)) this->Free(var);
+ if (!is_leaf_alloc(var)) this->Free(var, scope);
}
}
// scope_pair_offset >= 0 means it is either
// - leaf stmt(offset = 0)
// - beginning of scope(offset < 0)
// In both cases, we need to handle the gen event correctly
- if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
+ if (it != scope.event_map.end() && seq[i].scope_pair_offset >= 0) {
for (const VarNode* var : it->second.gen) {
- TVM_FFI_ICHECK(shmem_allocs_.count(var));
- const Buffer& buf = shmem_allocs_.at(var);
- StorageEntry* dst_entry = FindAlloc(buf);
- alloc_map_[var] = dst_entry;
+ TVM_FFI_ICHECK(scope.shmem_allocs.count(var));
+ const Buffer& buf = scope.shmem_allocs.at(var);
+ StorageEntry* dst_entry = FindAlloc(buf, scope);
+ scope.alloc_map[var] = dst_entry;
}
}
- if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+ if (it != scope.event_map.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode* var : it->second.kill) {
- if (is_leaf_alloc(var)) this->Free(var);
+ if (is_leaf_alloc(var)) this->Free(var, scope);
+ }
+ }
+ }
+ }
+
+ /*!
+ * \brief Compute byte offsets for all entries in the scope after PlanMemory.
+ * \param scope the kernel scope whose offset map to fill
+ */
+ void ComputeOffsets(KernelScope& scope) {
+ int max_layer_num = 0;
+ std::vector<const StorageEntry*> all_entry;
+ for (const auto& e : scope.const_free_map) {
+ all_entry.push_back(e.second);
+ }
+ for (const StorageEntry* e : scope.sym_free_list) {
+ all_entry.push_back(e);
+ }
+ for (const StorageEntry* e : all_entry) {
+ max_layer_num = std::max(max_layer_num,
static_cast<int>(e->allocs.size()));
+ }
+ // calculate align for each layer of each storage entry.
+ std::vector<int> align(max_layer_num, 0);
+ for (const StorageEntry* e : all_entry) {
+ for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
+ for (const VarNode* buffer : e->allocs[i]) {
+ const Buffer& buf = scope.shmem_allocs.at(buffer);
+ align[i] = std::max(align[i], buf->dtype.bytes());
}
}
}
+ // calculate offset for each buffer based on the align of each layer
+ for (const StorageEntry* e : all_entry) {
+ PrimExpr max_inner_offset = 0;
+ for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
+ PrimExpr inner_offset = 0;
+ for (const VarNode* buffer : e->allocs[i]) {
+ const Buffer& buf = scope.shmem_allocs.at(buffer);
+ ffi::Array<PrimExpr> alloc_shape = GetBufferAllocationShape(buf);
+ int align_bytes = std::max(align[i], buf->dtype.bytes());
+ if (buf->data_alignment > 0) {
+ TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0)
+ << "The alignment of the buffer is not a multiple of the data
type size.";
+ align_bytes = buf->data_alignment;
+ }
+ PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes();
+ inner_offset +=
+ indexmod(align_bytes - indexmod(scope.merged_alloc_size +
inner_offset, align_bytes),
+ align_bytes);
+ scope.buffer_byte_offsets[buffer] = scope.merged_alloc_size +
inner_offset;
+ inner_offset += buffer_bytes;
+ }
+ max_inner_offset = max(max_inner_offset, inner_offset);
+ }
+ scope.merged_alloc_size = scope.merged_alloc_size + max_inner_offset;
+ }
}
+
/*!
* \brief Allocate new storage entry.
* \param buf the buffer object
@@ -590,12 +688,14 @@ class SharedMemoryRewriter : public StmtExprMutator {
entry->const_nbits = const_nbits;
return entry;
}
+
/*!
* \brief find the storage entry in the free list for the buffer
* \param buf the buffer object
+ * \param scope the kernel scope whose free lists to search
* \return the storage entry
*/
- StorageEntry* FindAlloc(const Buffer& buf) {
+ StorageEntry* FindAlloc(const Buffer& buf, KernelScope& scope) {
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
@@ -611,17 +711,17 @@ class SharedMemoryRewriter : public StmtExprMutator {
if (const_nbits != 0) {
// constant allocation.
- auto begin = const_free_map_.lower_bound(0);
- auto mid = const_free_map_.lower_bound(const_nbits);
- auto end = const_free_map_.upper_bound(const_nbits * match_range);
+ auto begin = scope.const_free_map.lower_bound(0);
+ auto mid = scope.const_free_map.lower_bound(const_nbits);
+ auto end = scope.const_free_map.upper_bound(const_nbits * match_range);
// Start looking at the buffer that is bigger than the required size
first.
// If we find one, directly allocate the buffer in its location and
remove its entry in the
// free list
for (auto it = mid; it != end; ++it) {
StorageEntry* e = it->second;
e->const_nbits = std::max(const_nbits, e->const_nbits);
- const_free_map_.erase(it);
- it->second->allocs.push_back({buf->data.get()});
+ scope.const_free_map.erase(it);
+ e->allocs.push_back({buf->data.get()});
return e;
}
// Then start looking at smaller buffers.
@@ -654,16 +754,16 @@ class SharedMemoryRewriter : public StmtExprMutator {
e->const_nbits = std::max(const_nbits, mem_ct);
e->allocs = reuse_allocs;
for (auto it : delete_it) {
- const_free_map_.erase(it);
+ scope.const_free_map.erase(it);
}
return e;
}
} else {
// if its symbolic allocation, just arbitrarily choose one entry to fit
in because we don't
// know its actual size
- for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it)
{
+ for (auto it = scope.sym_free_list.begin(); it !=
scope.sym_free_list.end(); ++it) {
StorageEntry* e = *it;
- sym_free_list_.erase(it);
+ scope.sym_free_list.erase(it);
return e;
}
}
@@ -673,10 +773,11 @@ class SharedMemoryRewriter : public StmtExprMutator {
/*!
* \brief add the storage entry to the buffer var into the free list.
* \param var the buffer var
+ * \param scope the kernel scope whose free lists to update
*/
- void Free(const VarNode* var) {
- auto it = alloc_map_.find(var);
- TVM_FFI_ICHECK(it != alloc_map_.end());
+ void Free(const VarNode* var, KernelScope& scope) {
+ auto it = scope.alloc_map.find(var);
+ TVM_FFI_ICHECK(it != scope.alloc_map.end());
StorageEntry* e = it->second;
TVM_FFI_ICHECK_NE(e->allocs.size(), 0U);
@@ -685,51 +786,41 @@ class SharedMemoryRewriter : public StmtExprMutator {
// normal free.
if (e->const_nbits != 0) {
- const_free_map_.insert({e->const_nbits, e});
+ scope.const_free_map.insert({e->const_nbits, e});
} else {
- sym_free_list_.push_back(e);
+ scope.sym_free_list.push_back(e);
}
}
+
// Whether enable dynamic analysis.
bool is_dynamic_{true};
- // The var for the merged buffer
- Var merged_buf_var_{"buf_dyn_shmem",
PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
- // The mapping from the original buffer var to its Buffer
- std::unordered_map<const VarNode*, Buffer> shmem_allocs_;
- // The size of the merged buffer
- PrimExpr merged_alloc_size_{0};
- // The mapping from the original buffer var to its offset in the merged
buffer
- std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
- // The mapping from the original buffer objects to their location in the
merged buffer.
- std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
- // The flag indicating whether the merged buffer has been allocated
- bool allocated_{false};
- // Whether any original shared memory allocation had the volatile annotation
- bool has_volatile_alloc_{false};
- // Locations of free ops.
- std::unordered_map<const ffi::Object*, EventEntry> event_map_;
- // constant size free map.
- std::multimap<uint64_t, StorageEntry*> const_free_map_;
- // symbolic free list, for non constant items.
- std::list<StorageEntry*> sym_free_list_;
- // The allocation assign map
- std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
- /*! \brief allocator of all the StorageEntry*/
+ // Whether already inside a thread_extent (outermost only).
+ bool in_thread_env_{false};
+ // Stack of per-kernel-launch scopes. Pushed on thread_extent entry, popped
on exit.
+ std::vector<KernelScope> scope_stack_;
+ /*! \brief allocator of all the StorageEntry (shared across all scopes) */
support::Arena arena_;
};
Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
- AllocateCollector collector;
- collector(stmt);
- if (collector.dyn_shmem_allocs_.size() > 1) {
- SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
- rewriter.PlanReuse(stmt);
- stmt = rewriter(std::move(stmt));
+ // Function-level early-out: skip the rewriter entirely if the PrimFunc
+ // has ≤1 dynamic shared-memory allocation (nothing to merge).
+ {
+ AllocateCollector dyn_probe(/*is_dynamic=*/true);
+ dyn_probe(stmt);
+ if (dyn_probe.shmem_allocs_.size() > 1) {
+ SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true);
+ stmt = dyn_rewriter(std::move(stmt));
+ }
}
- if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
- SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false);
- rewriter.PlanReuse(stmt, false);
- stmt = rewriter(std::move(stmt));
+ if (merge_static_smem) {
+ // Similarly skip the static rewriter if there is ≤1 static shmem alloc.
+ AllocateCollector static_probe(/*is_dynamic=*/false);
+ static_probe(stmt);
+ if (static_probe.shmem_allocs_.size() > 1) {
+ SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false);
+ stmt = static_rewriter(std::move(stmt));
+ }
}
return stmt;
}
diff --git a/src/tirx/transform/lower_device_kernel_launch.cc
b/src/tirx/transform/lower_device_kernel_launch.cc
index 9b38c4d629..af30af6bfb 100644
--- a/src/tirx/transform/lower_device_kernel_launch.cc
+++ b/src/tirx/transform/lower_device_kernel_launch.cc
@@ -213,6 +213,21 @@ class DeviceKernelMutator : public StmtExprMutator {
auto it = device_info_map_.find(gvar.get());
TVM_FFI_ICHECK(it != device_info_map_.end());
current_target_ = it->second.target;
+ // Track whether the caller is a host function (i.e. its target
+ // still has a host attached) and capture its host target. The
+ // same-target shortcut at the call site is only safe when caller
+ // and callee are both device-resident; a host caller must take
+ // the kernel-launch path even if Target::WithoutHost() makes the
+ // strings match. Conversely, a host caller invoking another host
+ // helper (e.g. a same-target subroutine that SplitHostDevice
+ // emitted on the host side) should compare against the host
+ // target, not the device target stripped by WithoutHost().
+ auto full_target = func->GetAttr<Target>(tvm::attr::kTarget).value();
+ if (full_target->GetHost().defined()) {
+ current_caller_host_target_ = full_target->GetHost().value();
+ } else {
+ current_caller_host_target_ = std::nullopt;
+ }
auto body = VisitStmt(func->body);
if (!body.same_as(func->body)) {
@@ -220,6 +235,7 @@ class DeviceKernelMutator : public StmtExprMutator {
}
current_target_ = std::nullopt;
+ current_caller_host_target_ = std::nullopt;
return func;
}
@@ -272,29 +288,59 @@ class DeviceKernelMutator : public StmtExprMutator {
<< gvar->name_hint << " did not appear within the IRModule";
const KernelInfo& dev_info = it->second;
- auto caller_target = current_target_.value();
auto callee_target = dev_info.target;
- bool same_target = caller_target->str() == callee_target->str();
- if (same_target) {
- // Calls within the same target may be handled at codegen time
- // as internal subroutine calls.
- return node;
- }
+ // A callee with non-empty launch_params has thread_extent
+ // bindings in its body, i.e. it is a real device kernel that
+ // must be invoked via a kernel-launch ABI. Conversely a callee
+ // with empty launch_params is a plain subroutine (host helper
+ // or intra-device helper) and is never invoked via kernel launch.
+ bool callee_is_kernel = dev_info.launch_params.size() > 0;
+ bool caller_is_host = current_caller_host_target_.has_value();
+
+ // For host callers, comparisons against the callee target must
+ // use the caller's *host* target, not the device target stripped
+ // by WithoutHost(). This handles two cases that the device-side
+ // comparison gets wrong:
+ // 1. A host caller invoking a real device kernel whose
+ // WithoutHost() target happens to match (e.g. kernel target
+ // "cuda" matches "cuda+host=c" after stripping host). Must
+ // go through kernel launch, not the same-target shortcut.
+ // 2. A host caller invoking another host helper with a
+ // different host target (e.g. SplitHostDevice emits an
+ // "add_host" with target "c" while the host body still
+ // carries "cuda+host=c"). Must go through call_extern (or
+ // same-target subroutine), not kernel launch.
+ auto caller_target =
+ caller_is_host ? current_caller_host_target_.value() :
current_target_.value();
+
+ // A host caller invoking a real device kernel must always go
+ // through the kernel-launch ABI, regardless of any same-target /
+ // same-device-type coincidence.
+ bool force_kernel_launch = callee_is_kernel && caller_is_host;
+
+ if (!force_kernel_launch) {
+ bool same_target = caller_target->str() == callee_target->str();
+ if (same_target) {
+ // Calls within the same target may be handled at codegen time
+ // as internal subroutine calls.
+ return node;
+ }
- bool same_device_type =
- caller_target->GetTargetDeviceType() ==
callee_target->GetTargetDeviceType();
- if (same_device_type) {
- // Calls to another target using the same device (e.g. LLVM
- // calling a custom TIRToRuntime target) do not require a kernel
- // launch, but need to be replaced with call_extern.
- extern_function_call_.insert(gvar);
- ffi::Array<PrimExpr> args;
- args.push_back(StringImm(gvar->name_hint));
- for (const auto& arg : node->args) {
- args.push_back(arg);
+ bool same_device_type =
+ caller_target->GetTargetDeviceType() ==
callee_target->GetTargetDeviceType();
+ if (same_device_type) {
+ // Calls to another target using the same device (e.g. LLVM
+ // calling a custom TIRToRuntime target) do not require a kernel
+ // launch, but need to be replaced with call_extern.
+ extern_function_call_.insert(gvar);
+ ffi::Array<PrimExpr> args;
+ args.push_back(StringImm(gvar->name_hint));
+ for (const auto& arg : node->args) {
+ args.push_back(arg);
+ }
+ return Call(node->dtype, builtin::call_extern(), args);
}
- return Call(node->dtype, builtin::call_extern(), args);
}
TVM_FFI_ICHECK(dev_info.launch_params.defined())
@@ -336,6 +382,13 @@ class DeviceKernelMutator : public StmtExprMutator {
}
ffi::Optional<Target> current_target_;
+ // The host target of the caller currently being rewritten, if the
+ // caller is a host function (its kTarget has a host attached).
+ // Used both to detect that the caller is a host function and to
+ // compare against the callee target on the host side, so that
+ // host-to-host subroutine calls are not misrouted through the
+ // device kernel-launch ABI.
+ ffi::Optional<Target> current_caller_host_target_;
std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
std::unordered_set<const GlobalVarNode*> extern_function_call_;
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
index ca7d1de7c4..b09c1fd796 100644
---
a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
+++
b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -254,23 +254,100 @@ def test_async_copy():
class Before:
@T.prim_func(s_tir=True)
def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,),
"float32")):
+ threadIdx_x = T.launch_thread("threadIdx.x", 128)
A_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
B_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
- threadIdx_x = T.launch_thread("threadIdx.x", 128)
T.ptx.cp_async("float32", A_sh.data, threadIdx_x, A.data,
threadIdx_x, 512)
T.ptx.cp_async("float32", B_sh.data, threadIdx_x, B.data,
threadIdx_x, 512)
After = transform(Before)
- # The pass merges shared.dyn allocations but DeclBuffer nodes from the
original
- # allocations remain with remapped data vars. The output can't be precisely
- # represented in TVMScript due to same-name var constraints, so we verify
- # key properties instead of exact structural equality.
+ # The pass merges shared.dyn allocations. A_sh and B_sh are accessed
+ # sequentially inside the thread_extent with non-overlapping lifetimes,
+ # so the liveness analysis allows reuse — both fit in 512 bytes
+ # (= 128 elements * 4 bytes).
script = After["main"].script()
- # Verify merged allocation (1024 bytes = 128*4 + 128*4)
- assert '"uint8"' in script and '"shared.dyn"' in script and "(1024,)" in
script
- # Verify cp_async uses correct byte offsets
+ # Verify merged allocation (512 bytes - A_sh and B_sh can be reused)
+ assert '"uint8"' in script and '"shared.dyn"' in script and "(512,)" in
script
+ # Verify cp_async uses the merged buffer
+ assert "buf_dyn_shmem" in script
assert "threadIdx_x * 4" in script
- assert "(128 + threadIdx_x) * 4" in script
+
+
+def test_multi_thread_extent_blocks():
+ """Each thread_extent block must get its own merged buffer.
+
+ Reproduces the scoping bug from PR #19605: a single PrimFunc
+ with two sibling thread_extent regions, each containing its
+ own shared.dyn allocations. The merged buffer must be allocated
+ inside each kernel body — not just the first.
+ """
+ transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
+
+ @I.ir_module(check_well_formed=False)
+ class Before:
+ @T.prim_func(s_tir=True, check_well_formed=False)
+ def main(
+ X: T.Buffer((128,), "float32"),
+ Y: T.Buffer((128,), "float32"),
+ ):
+ X_flat = T.decl_buffer(128, data=X.data)
+ Y_flat = T.decl_buffer(128, data=Y.data)
+
+ # First kernel launch
+ tx0 = T.env_thread("threadIdx.x")
+ with T.attr(tx0, "thread_extent", 128):
+ A_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+ B_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+ A_sh[tx0] = X_flat[tx0]
+ B_sh[tx0] = A_sh[tx0]
+ X_flat[tx0] = B_sh[tx0]
+
+ # Second kernel launch — must NOT see kernel #0's merged buffer.
+ tx1 = T.env_thread("threadIdx.x")
+ with T.attr(tx1, "thread_extent", 128):
+ C_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+ D_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+ C_sh[tx1] = Y_flat[tx1]
+ D_sh[tx1] = C_sh[tx1]
+ Y_flat[tx1] = D_sh[tx1]
+
+ After = transform(Before)
+ script = After["main"].script()
+
+ # Two merged allocations — one per thread_extent body.
+ # Each of the four original 128-float32 buffers (A_sh, B_sh, C_sh, D_sh)
+ # gets merged within its own kernel scope.
+ assert script.count("shared.dyn") >= 2, (
+ "Expected at least two shared.dyn allocations (one per kernel)"
+ )
+ assert script.count("alloc_buffer") >= 2, (
+ "Expected at least two alloc_buffer nodes (one merged buf per kernel)"
+ )
+
+ # Each thread_extent block must contain its own merged buffer; the two
+ # must NOT share one buf_dyn_shmem variable. The check below is textual:
+ # split the printed script at the second thread_extent and require
+ # "buf_dyn_shmem" in each half (second half only if "tx1" is printed).
+ first_block = script.split("with T.attr(tx1")[0]
+ second_block = script.split("with T.attr(tx1")[1] if "tx1" in script else
""
+ assert "buf_dyn_shmem" in first_block, "Kernel 1 must have a merged buffer"
+ if second_block:
+ assert "buf_dyn_shmem" in second_block, "Kernel 2 must have a merged
buffer"
+
+ # End-to-end: post-merge IR must remain well-formed through
+ # the host/device split — this is the exact ordering from
+ # PR #19605 that triggers the scoping bug.
+ target = tvm.target.Target("llvm")
+ mod_with_target = tvm.IRModule({"main": After["main"].with_attr({"target":
target})})
+ split = tvm.transform.Sequential(
+ [
+ tvm.tirx.transform.AnnotateDeviceRegions(),
+ tvm.tirx.transform.SplitHostDevice(),
+ ]
+ )
+ # If kernel #1 referenced an undefined buf_dyn_shmem, this
+ # would raise during well-formedness checking inside SplitHostDevice.
+ split(mod_with_target)
if __name__ == "__main__":