This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 969fad363b [TIR] Add VisitBufferDef/VisitBufferUse to base StmtVisitor/StmtMutator (#18873)
969fad363b is described below
commit 969fad363be13375a4f4ecbc6a5fb2030e8e3f41
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Mar 4 22:06:03 2026 -0500
[TIR] Add VisitBufferDef/VisitBufferUse to base StmtVisitor/StmtMutator (#18873)
---
include/tvm/tir/stmt_functor.h | 42 +++-
src/s_tir/transform/renew_defs.cc | 105 +++------
src/tir/analysis/var_use_def_analysis.cc | 50 ++--
src/tir/analysis/var_use_def_analysis.h | 10 +-
src/tir/analysis/verify_well_formed.cc | 6 +-
src/tir/ir/data_type_rewriter.cc | 70 ++----
src/tir/ir/data_type_rewriter.h | 8 +-
src/tir/ir/specialize.cc | 31 +--
src/tir/ir/stmt_functor.cc | 251 ++++++++++-----------
src/tir/ir/tir_visitor_with_path.cc | 21 +-
src/tir/ir/tir_visitor_with_path.h | 8 +-
src/tir/transform/ir_utils.cc | 10 +
src/tir/transform/simplify.cc | 13 ++
tests/cpp/ir_functor_test.cc | 50 ++--
.../test_tir_analysis_undefined_vars.py | 93 ++++++++
.../tir-transform/test_tir_transform_simplify.py | 45 +++-
16 files changed, 451 insertions(+), 362 deletions(-)
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 884f207b97..e86c6bb125 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -144,6 +144,20 @@ class TVM_DLL StmtVisitor : protected
StmtFunctor<void(const Stmt&)> {
* and redirect Visit to ExprMutator::VisitExpr(Expr)
*/
virtual void VisitExpr(const PrimExpr& e) {}
+ /*!
+ * \brief Visit buffer at definition site (AllocBuffer, DeclBuffer, SBlock
alloc_buffers).
+ * Visits buffer shape, strides, elem_offset via VisitExpr.
+ * \param buffer The buffer being defined.
+ * \param alloc_data If true, the buffer's data pointer is a new allocation
(AllocBuffer);
+ * if false, data references an existing variable (DeclBuffer).
+ */
+ virtual void VisitBufferDef(const Buffer& buffer, bool alloc_data);
+ /*!
+ * \brief Visit buffer at use site (BufferStore, BufferLoad, SBlock
reads/writes).
+ * By default, this is a no-op, as buffer fields (shape, strides,
elem_offset)
+ * are visited at their definition site.
+ */
+ virtual void VisitBufferUse(const Buffer& buffer);
// statement visitor
void VisitStmt_(const AttrStmtNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
@@ -179,6 +193,8 @@ class TVM_DLL StmtMutator : protected
StmtFunctor<Stmt(const Stmt&)> {
}
protected:
+ /*! \brief Map from old buffer to new buffer, populated by VisitBufferDef. */
+ ffi::Map<Buffer, Buffer> buffer_remap_;
// We perform copy on write optimizations on the StmtMutator
// so that an unique copy of parent can be mutated inplace
// when some of its children changed.
@@ -240,6 +256,22 @@ class TVM_DLL StmtMutator : protected
StmtFunctor<Stmt(const Stmt&)> {
* and redirect Mutate to ExprMutator::Mutate(Expr)
*/
virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
+ /*!
+ * \brief Visit buffer at definition site. Visits shape/strides/elem_offset
via VisitExpr.
+ * If any field changes, creates a new buffer and records it in
buffer_remap_.
+ * \param buffer The buffer being defined.
+ * \param alloc_data If true, the buffer's data pointer is a new allocation
(AllocBuffer);
+ * if false, data references an existing variable (DeclBuffer).
+ * \return The (possibly new) buffer.
+ */
+ virtual Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data);
+ /*!
+ * \brief Visit buffer at use site (BufferStore, BufferLoad, SBlock
reads/writes).
+ * By default, returns the remapped buffer from buffer_remap_ if exists,
otherwise
+ * returns the original buffer. Buffer fields are visited at their
definition site.
+ * \return The (possibly remapped) buffer.
+ */
+ virtual Buffer VisitBufferUse(const Buffer& buffer);
// statement visitor
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const IfThenElseNode* op) override;
@@ -276,31 +308,35 @@ class TVM_DLL StmtMutator : protected
StmtFunctor<Stmt(const Stmt&)> {
/*!
* \brief Visitor that recursively visit stmts and exprs on them.
*/
-class StmtExprVisitor : public StmtVisitor, public ExprVisitor {
+class StmtExprVisitor : public ExprVisitor, public StmtVisitor {
public:
using StmtVisitor::operator();
using ExprVisitor::operator();
protected:
using ExprVisitor::VisitExpr;
+ using ExprVisitor::VisitExpr_;
using StmtVisitor::VisitStmt;
void VisitExpr(const PrimExpr& e) override { return
ExprVisitor::VisitExpr(e); }
+ void VisitExpr_(const BufferLoadNode* op) override;
};
/*!
* \brief Mutator that recursively mutates stmts and exprs on them.
*/
-class StmtExprMutator : public StmtMutator, public ExprMutator {
+class StmtExprMutator : public ExprMutator, public StmtMutator {
public:
using StmtMutator::operator();
using ExprMutator::operator();
protected:
using ExprMutator::VisitExpr;
- using StmtMutator::VisitExpr;
+ using ExprMutator::VisitExpr_;
+ using StmtMutator::VisitStmt;
PrimExpr VisitExpr(const PrimExpr& e) override { return
ExprMutator::VisitExpr(e); }
+ PrimExpr VisitExpr_(const BufferLoadNode* op) override;
};
/*!
diff --git a/src/s_tir/transform/renew_defs.cc
b/src/s_tir/transform/renew_defs.cc
index f48cfff701..224fccbadb 100644
--- a/src/s_tir/transform/renew_defs.cc
+++ b/src/s_tir/transform/renew_defs.cc
@@ -71,7 +71,7 @@ class RenewDefMutator : public StmtExprMutator {
if (param->dtype.is_handle()) {
const Buffer& buffer = func->buffer_map.at(param);
Var new_param = Downcast<Var>(generator.VisitExpr(param));
- Buffer new_buffer = generator.VisitBuffer(buffer, true);
+ Buffer new_buffer = generator.DefineBuffer(buffer);
buffer_map.Set(new_param, new_buffer);
}
}
@@ -102,40 +102,24 @@ class RenewDefMutator : public StmtExprMutator {
STMT_REGENERATE_VAR_DEF(LetStmtNode, var);
STMT_REGENERATE_VAR_DEF(ForNode, loop_var);
- Stmt VisitStmt_(const AllocBufferNode* op) final {
- Buffer new_buffer = VisitBuffer(op->buffer, /*define=*/true);
- Stmt body = this->VisitStmt(op->body);
- if (new_buffer.same_as(op->buffer) && body.same_as(op->body)) {
- return ffi::GetRef<Stmt>(op);
- } else {
- auto n = ffi::make_object<AllocBufferNode>(*op);
- n->buffer = std::move(new_buffer);
- n->body = std::move(body);
- return Stmt(n);
- }
+ // Override VisitBufferDef to create fresh buffer copies at definition sites
+ // (AllocBuffer, DeclBuffer, SBlock alloc_buffers, match_buffers)
+ Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) final {
+ return DefineBuffer(buffer);
}
- Stmt VisitStmt_(const DeclBufferNode* op) final {
- Buffer new_buffer = VisitBuffer(op->buffer, /*define=*/true);
- Stmt body = this->VisitStmt(op->body);
- if (new_buffer.same_as(op->buffer) && body.same_as(op->body)) {
- return ffi::GetRef<Stmt>(op);
- } else {
- auto n = ffi::make_object<DeclBufferNode>(*op);
- n->buffer = std::move(new_buffer);
- n->body = std::move(body);
- return Stmt(n);
- }
- }
+ // Override VisitBufferUse to remap buffers at use sites
+ // (BufferStore, BufferLoad, SBlock reads/writes)
+ Buffer VisitBufferUse(const Buffer& buffer) final { return
UseOrRemapBuffer(buffer); }
Stmt VisitStmt_(const SBlockNode* op) final {
// Step 0. Re-define Itervars
ffi::Array<IterVar> iter_vars =
op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this,
std::placeholders::_1));
- // Step 1. Re-define buffers allocate under the block
- ffi::Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
- std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1,
/*define=*/true));
+ // Step 1. Re-define buffers allocated under the block
+ ffi::Array<Buffer> alloc_buffers =
+ op->alloc_buffers.Map([this](const Buffer& buf) { return
this->DefineBuffer(buf); });
// Step 2. Re-define match_buffers
ffi::Array<MatchBufferRegion> match_buffers = op->match_buffers.Map(
@@ -167,34 +151,6 @@ class RenewDefMutator : public StmtExprMutator {
return Stmt(n);
}
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<BufferStoreNode>();
- TVM_FFI_ICHECK(op != nullptr);
- Buffer buffer = VisitDeclOrRemapBuffer(op->buffer);
- if (buffer.same_as(op->buffer)) {
- return stmt;
- } else {
- auto n = ffi::make_object<BufferStoreNode>(*op);
- n->buffer = std::move(buffer);
- return BufferStore(n);
- }
- }
-
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<BufferLoadNode>();
- TVM_FFI_ICHECK(op != nullptr);
- Buffer buffer = VisitDeclOrRemapBuffer(op->buffer);
- if (buffer.same_as(op->buffer)) {
- return expr;
- } else {
- auto n = ffi::make_object<BufferLoadNode>(*op);
- n->buffer = std::move(buffer);
- return BufferLoad(n);
- }
- }
-
private:
Var ReDefineVar(const Var& var) {
Var new_var = Var(ffi::make_object<VarNode>(*var.get()));
@@ -208,12 +164,11 @@ class RenewDefMutator : public StmtExprMutator {
remap_.Set(source, target);
}
- Buffer VisitBuffer(const Buffer& buffer, bool define = false) {
+ Buffer DefineBuffer(const Buffer& buffer) {
auto it = remap_.find(buffer);
if (it != remap_.end()) {
return Downcast<Buffer>((*it).second);
}
- TVM_FFI_ICHECK(define);
auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr {
auto it = remap_.find(expr);
@@ -247,24 +202,9 @@ class RenewDefMutator : public StmtExprMutator {
return new_buffer;
}
- IterVar VisitIterVar(const IterVar& iter_var) {
- auto it = remap_.find(iter_var);
- if (it != remap_.end()) {
- return Downcast<IterVar>((*it).second);
- }
- PrimExpr min = VisitExpr(iter_var->dom->min);
- PrimExpr extent = VisitExpr(iter_var->dom->extent);
- IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var),
iter_var->iter_type,
- iter_var->thread_tag);
- this->AddDefRemap(iter_var, new_iter_var);
- return new_iter_var;
- }
-
- Buffer VisitDeclOrRemapBuffer(const Buffer& buffer) {
+ Buffer UseOrRemapBuffer(const Buffer& buffer) {
// If the buffer has been remapped, return the remapped buffer, otherwise,
- // return the declared one.
- // Due to a recent PR, we can allow undefined buffer appearing in
BufferLoad/Store. We need
- // to remap them but will not create new var
+ // remap it without creating new var definitions.
auto it = remap_.find(buffer);
if (it != remap_.end()) {
return Downcast<Buffer>((*it).second);
@@ -286,8 +226,21 @@ class RenewDefMutator : public StmtExprMutator {
return new_buffer;
}
+ IterVar VisitIterVar(const IterVar& iter_var) {
+ auto it = remap_.find(iter_var);
+ if (it != remap_.end()) {
+ return Downcast<IterVar>((*it).second);
+ }
+ PrimExpr min = VisitExpr(iter_var->dom->min);
+ PrimExpr extent = VisitExpr(iter_var->dom->extent);
+ IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var),
iter_var->iter_type,
+ iter_var->thread_tag);
+ this->AddDefRemap(iter_var, new_iter_var);
+ return new_iter_var;
+ }
+
MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) {
- Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true);
+ Buffer buffer = DefineBuffer(match_buffer->buffer);
BufferRegion region = VisitBufferRegion(match_buffer->source);
return MatchBufferRegion(std::move(buffer), std::move(region));
}
@@ -303,7 +256,7 @@ class RenewDefMutator : public StmtExprMutator {
}
BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) {
- Buffer buffer = VisitBuffer(buffer_region->buffer);
+ Buffer buffer = UseOrRemapBuffer(buffer_region->buffer);
ffi::Array<Range> region = buffer_region->region.Map(
std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1));
if (buffer.same_as(buffer_region->buffer) &&
region.same_as(buffer_region->region)) {
diff --git a/src/tir/analysis/var_use_def_analysis.cc
b/src/tir/analysis/var_use_def_analysis.cc
index 9c4d26b58c..b2236e28ce 100644
--- a/src/tir/analysis/var_use_def_analysis.cc
+++ b/src/tir/analysis/var_use_def_analysis.cc
@@ -64,21 +64,8 @@ void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) {
StmtExprVisitor::VisitStmt_(op);
}
-void VarUseDefAnalyzer::VisitStmt_(const DeclBufferNode* op) {
- this->HandleDef(op->buffer);
- StmtExprVisitor::VisitStmt_(op);
-}
-
void VarUseDefAnalyzer::VisitStmt_(const AllocBufferNode* op) {
- // AllocBuffer both allocates the data variable and declares the buffer,
- // so we must define buffer->data before the buffer itself.
- this->HandleDef(op->buffer->data);
- this->HandleDef(op->buffer);
- StmtExprVisitor::VisitStmt_(op);
-}
-
-void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) {
- HandleUse(op->buffer);
+ // VisitBufferDef (called by base) defines buffer->data and the buffer
itself.
StmtExprVisitor::VisitStmt_(op);
}
@@ -113,9 +100,29 @@ void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) {
StmtExprVisitor::VisitExpr_(op);
}
-void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) {
- HandleUse(op->buffer);
- StmtExprVisitor::VisitExpr_(op);
+void VarUseDefAnalyzer::VisitBufferDef(const Buffer& buffer, bool alloc_data) {
+ if (alloc_data) {
+ // AllocBuffer / SBlock: data is a new allocation — define it.
+ if (!use_count_.count(buffer->data.get())) {
+ HandleDef(buffer->data);
+ }
+ } else {
+ // DeclBuffer: data references an existing variable — use it.
+ HandleUse(buffer->data);
+ }
+ HandleDef(buffer);
+ // Visit shape/strides/elem_offset as uses of vars from the enclosing scope.
+ for (const auto& e : buffer->shape) this->VisitExpr(e);
+ for (const auto& e : buffer->strides) this->VisitExpr(e);
+ this->VisitExpr(buffer->elem_offset);
+}
+
+void VarUseDefAnalyzer::VisitBufferUse(const Buffer& buffer) {
+ HandleUse(buffer);
+ // Buffer data pointer must be tracked as a use — the use site
+ // reads/writes through this pointer. Without this, UndefinedVars
+ // misses data vars for buffers whose DeclBuffer is outside the scope.
+ HandleUse(buffer->data);
}
void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) {
@@ -162,8 +169,8 @@ void VarUseDefAnalyzer::HandleDef(const Buffer& buf) {
<< "buffer " << ptr->name << " has been used before definition!";
buffer_use_count_[ptr] = 0;
buffer_def_count_[ptr] = 1;
-
- VisitBuffer(buf);
+ // Buffer fields (data, shape, strides) are visited by the caller
+ // (VisitBufferDef) via the base class, not here.
}
void VarUseDefAnalyzer::HandleUse(const Buffer& buf) {
@@ -177,8 +184,9 @@ void VarUseDefAnalyzer::HandleUse(const Buffer& buf) {
undefined_buffers_.push_back(ffi::GetRef<Buffer>(ptr));
buffer_use_count_[ptr] = -1;
}
-
- VisitBuffer(buf);
+ // Buffer fields (shape, strides, data) are visited at the definition
+ // site via VisitBufferDef. Do not re-visit them at use sites, as the
+ // buffer's shape variables may not be in scope at the point of use.
}
ffi::Array<Var> UndefinedVars(const Stmt& stmt, const ffi::Array<Var>& args) {
diff --git a/src/tir/analysis/var_use_def_analysis.h
b/src/tir/analysis/var_use_def_analysis.h
index 1c089c048f..7196fb2e8f 100644
--- a/src/tir/analysis/var_use_def_analysis.h
+++ b/src/tir/analysis/var_use_def_analysis.h
@@ -61,19 +61,19 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
void VisitStmt_(const ForNode* op) final;
- void VisitStmt_(const DeclBufferNode* op) final;
-
void VisitStmt_(const AllocBufferNode* op) final;
- void VisitStmt_(const BufferStoreNode* op) final;
-
void VisitExpr_(const LetNode* op) final;
void VisitExpr_(const VarNode* op) final;
void VisitExpr_(const ReduceNode* op) final;
- void VisitExpr_(const BufferLoadNode* op) final;
+ // Piggyback on base class VisitBufferDef/VisitBufferUse to handle buffer
+ // def/use tracking. Base class calls these from AllocBuffer, DeclBuffer,
+ // BufferStore, BufferLoad, and SBlock visitors.
+ void VisitBufferDef(const Buffer& buffer, bool alloc_data) final;
+ void VisitBufferUse(const Buffer& buffer) final;
void HandleDef(const Var& v);
void HandleUse(const Var& v);
diff --git a/src/tir/analysis/verify_well_formed.cc
b/src/tir/analysis/verify_well_formed.cc
index 0ff363a547..5d8a2d5778 100644
--- a/src/tir/analysis/verify_well_formed.cc
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -348,7 +348,7 @@ class UndefinedBufferVerifier : public
Verifier<UndefinedBufferVerifier> {
previously_defined_.insert({buffer, path});
}
- void Visit(const Buffer& buffer, AccessPath path) override {
+ void VisitBufferUse(const Buffer& buffer, AccessPath path) override {
bool is_declared = currently_defined_.count(buffer);
bool was_declared = previously_defined_.count(buffer);
@@ -363,8 +363,8 @@ class UndefinedBufferVerifier : public
Verifier<UndefinedBufferVerifier> {
Verify(false) << "TIR is ill-formed: buffer " << buffer->name << " is
used at " << path
<< " without a prior DeclBuffer or other declaration.";
}
- // Still visit the buffer's internal vars so variable usage is tracked.
- Verifier::Visit(buffer, path);
+ // Buffer fields are visited at definition site (EnterDef), not here.
+ Verifier::VisitBufferUse(buffer, path);
}
// Buffers defined in the currently-visited scope.
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index c4d83c2236..f7c508dac9 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -277,22 +277,16 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const
AttrStmtNode* op) {
return DataTypeLegalizer::VisitStmt_(op);
}
-Stmt IndexDataTypeRewriter::VisitStmt_(const AllocBufferNode* op) {
- Buffer new_buffer = VisitBuffer(op->buffer);
- AllocBuffer alloc_buffer =
Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
- if (!new_buffer.same_as(op->buffer)) {
- alloc_buffer.CopyOnWrite()->buffer = new_buffer;
- }
- return alloc_buffer;
+Buffer IndexDataTypeRewriter::VisitBufferDef(const Buffer& buffer, bool
alloc_data) {
+ bool is_enabled = is_enabled_;
+ is_enabled_ = true;
+ Buffer new_buf = StmtMutator::VisitBufferDef(buffer, alloc_data);
+ is_enabled_ = is_enabled;
+ return new_buf;
}
-Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
- Buffer new_buffer = VisitBuffer(op->buffer);
- DeclBuffer decl_buffer =
Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
- if (!new_buffer.same_as(op->buffer)) {
- decl_buffer.CopyOnWrite()->buffer = new_buffer;
- }
- return decl_buffer;
+Buffer IndexDataTypeRewriter::VisitBufferUse(const Buffer& buffer) {
+ return StmtMutator::VisitBufferUse(buffer);
}
Stmt IndexDataTypeRewriter::VisitStmt_(const SBlockRealizeNode* op) {
@@ -322,11 +316,11 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const
SBlockRealizeNode* op) {
}
Stmt IndexDataTypeRewriter::VisitStmt_(const SBlockNode* op) {
- ffi::Array<Buffer> new_alloc_buffers =
- op->alloc_buffers.Map([this](const Buffer& buffer) { return
this->VisitBuffer(buffer); });
+ ffi::Array<Buffer> new_alloc_buffers = op->alloc_buffers.Map(
+ [this](const Buffer& buffer) { return this->VisitBufferDef(buffer,
/*alloc_data=*/true); });
ffi::Array<MatchBufferRegion> new_match_buffers =
op->match_buffers.Map([this](const MatchBufferRegion&
match_buffer_region) {
- Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer);
+ Buffer new_buffer = this->VisitBufferDef(match_buffer_region->buffer,
/*alloc_data=*/true);
BufferRegion new_buffer_region =
this->VisitBufferRegion(match_buffer_region->source);
if (!new_buffer.same_as(match_buffer_region->buffer) ||
!new_buffer_region.same_as(match_buffer_region->source)) {
@@ -378,7 +372,7 @@ ffi::Map<ffi::String, ffi::Any>
IndexDataTypeRewriter::VisitBlockAnnotations(
}
if (obj.as<BufferNode>()) {
Buffer buffer = Downcast<Buffer>(obj);
- if (Buffer new_buffer = GetRemappedBuffer(buffer);
!new_buffer.same_as(buffer)) {
+ if (Buffer new_buffer = VisitBufferUse(buffer);
!new_buffer.same_as(buffer)) {
return new_buffer;
}
} else if (obj.as<ffi::ArrayObj>()) {
@@ -397,13 +391,6 @@ ffi::Map<ffi::String, ffi::Any>
IndexDataTypeRewriter::VisitBlockAnnotations(
return new_annotations;
}
-Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) {
- if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
- return (*it).second;
- }
- return buffer;
-}
-
IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
bool is_enabled = is_enabled_;
is_enabled_ = true;
@@ -422,33 +409,8 @@ IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar&
iter_var) {
return iter_var;
}
-Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
- bool is_enabled = is_enabled_;
-
- is_enabled_ = true;
- ffi::Array<PrimExpr> new_shape =
- buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
- ffi::Array<PrimExpr> new_strides =
- buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e);
});
- auto new_elem_offset = VisitExpr(buffer->elem_offset);
- is_enabled_ = is_enabled;
-
- if (!buffer->shape.same_as(new_shape) ||
!buffer->strides.same_as(new_strides) ||
- !buffer->elem_offset.same_as(new_elem_offset)) {
- Buffer new_buffer = buffer;
- BufferNode* new_buffer_node = new_buffer.CopyOnWrite();
- new_buffer_node->shape = std::move(new_shape);
- new_buffer_node->strides = std::move(new_strides);
- new_buffer_node->elem_offset = std::move(new_elem_offset);
- buffer_remap_.Set(buffer, new_buffer);
- return new_buffer;
- } else {
- return buffer;
- }
-}
-
BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion&
buffer_region) {
- Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
+ Buffer remapped_buffer = VisitBufferUse(buffer_region->buffer);
bool is_enabled = is_enabled_;
is_enabled_ = true;
@@ -468,7 +430,7 @@ BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const
BufferRegion& buffer
Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
BufferStore store = ffi::GetRef<BufferStore>(op);
- Buffer new_buffer = GetRemappedBuffer(op->buffer);
+ Buffer new_buffer = VisitBufferUse(op->buffer);
auto value = this->VisitExpr(op->value);
if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
value = cast(new_buffer->dtype, value);
@@ -489,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const
BufferStoreNode* op) {
PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) {
BufferLoad load = ffi::GetRef<BufferLoad>(op);
- Buffer new_buffer = GetRemappedBuffer(op->buffer);
+ Buffer new_buffer = VisitBufferUse(op->buffer);
auto indices = VisitIndices(op->indices);
if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) {
@@ -642,7 +604,7 @@ PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
// start rewrite
ffi::Map<Var, Buffer> new_buffer_map = func->buffer_map;
for (const auto& [var, buffer] : func->buffer_map) {
- new_buffer_map.Set(var, VisitBuffer(buffer));
+ new_buffer_map.Set(var, VisitBufferDef(buffer, /*alloc_data=*/true));
}
// remap params
bool is_enabled = true;
diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h
index 8196485273..e886777096 100644
--- a/src/tir/ir/data_type_rewriter.h
+++ b/src/tir/ir/data_type_rewriter.h
@@ -101,6 +101,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
using Parent::VisitExpr_;
using Parent::VisitStmt_;
+ Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override;
+ Buffer VisitBufferUse(const Buffer& buffer) override;
Stmt VisitStmt_(const SBlockRealizeNode* op) override;
Stmt VisitStmt_(const SBlockNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
@@ -108,8 +110,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
ffi::Array<PrimExpr> VisitIndices(ffi::Array<PrimExpr> indices);
Stmt VisitStmt_(const IfThenElseNode* op) override;
- Stmt VisitStmt_(const DeclBufferNode* op) override;
- Stmt VisitStmt_(const AllocBufferNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
@@ -122,8 +122,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
Stmt VisitStmt_(const ForNode* op) override;
- Buffer VisitBuffer(const Buffer& buffer);
- Buffer GetRemappedBuffer(const Buffer& buffer);
ffi::Map<ffi::String, ffi::Any> VisitBlockAnnotations(
const ffi::Map<ffi::String, ffi::Any>& annotations);
BufferRegion VisitBufferRegion(const BufferRegion& region);
@@ -132,8 +130,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
bool is_enabled_{false};
// indicator of condition
bool is_condition_{false};
-
- ffi::Map<Buffer, Buffer> buffer_remap_;
};
/*!
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 6c4f4666ef..1ad0749711 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -177,35 +177,8 @@ class PrimFuncSpecializer : public StmtExprMutator {
return stmt;
}
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<BufferStoreNode>();
- TVM_FFI_ICHECK(op != nullptr);
-
- auto new_buf = GetNewBuffer(op->buffer);
- if (new_buf.same_as(op->buffer)) {
- return ffi::GetRef<BufferStore>(op);
- } else {
- auto n = CopyOnWrite(op);
- n->buffer = new_buf;
- return Stmt(n);
- }
- }
-
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<BufferLoadNode>();
- TVM_FFI_ICHECK(op != nullptr);
-
- auto new_buf = GetNewBuffer(op->buffer);
- if (new_buf.same_as(op->buffer)) {
- return ffi::GetRef<BufferLoad>(op);
- } else {
- auto n = ffi::make_object<BufferLoadNode>(*op);
- n->buffer = new_buf;
- return PrimExpr(n);
- }
- }
+ // Override VisitBufferUse to use our own buffer_map_ instead of base class
field visiting.
+ Buffer VisitBufferUse(const Buffer& buffer) final { return
GetNewBuffer(buffer); }
PrimExpr VisitExpr_(const VarNode* op) final {
auto it = var_map_.find(ffi::GetRef<Var>(op));
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ce26337e15..ff79c374db 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -57,11 +57,35 @@ void StmtVisitor::VisitStmt_(const WhileNode* op) {
this->VisitStmt(op->body);
}
-void StmtVisitor::VisitStmt_(const AllocBufferNode* op) {
this->VisitStmt(op->body); }
+void StmtVisitor::VisitBufferDef(const Buffer& buffer, bool alloc_data) {
+ for (const auto& e : buffer->shape) this->VisitExpr(e);
+ for (const auto& e : buffer->strides) this->VisitExpr(e);
+ this->VisitExpr(buffer->elem_offset);
+}
+
+// Default VisitBufferUse is empty: buffer fields (shape, strides, elem_offset)
+// are visited at the definition site (VisitBufferDef) and should not be
+// re-visited at each use site, as the use site may be in a different scope
+// where the buffer's shape variables are not defined.
+void StmtVisitor::VisitBufferUse(const Buffer& buffer) {}
+
+void StmtExprVisitor::VisitExpr_(const BufferLoadNode* op) {
+ this->VisitBufferUse(op->buffer);
+ ExprVisitor::VisitExpr_(op);
+}
-void StmtVisitor::VisitStmt_(const DeclBufferNode* op) {
this->VisitStmt(op->body); }
+void StmtVisitor::VisitStmt_(const AllocBufferNode* op) {
+ this->VisitBufferDef(op->buffer, /*alloc_data=*/true);
+ this->VisitStmt(op->body);
+}
+
+void StmtVisitor::VisitStmt_(const DeclBufferNode* op) {
+ this->VisitBufferDef(op->buffer, /*alloc_data=*/false);
+ this->VisitStmt(op->body);
+}
void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
+ this->VisitBufferUse(op->buffer);
this->VisitExpr(op->value);
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
@@ -88,6 +112,7 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) {
this->VisitExpr(op->value
void StmtVisitor::VisitStmt_(const SBlockNode* op) {
auto fvisit_buffer_region = [this](const BufferRegion& s) {
+ this->VisitBufferUse(s->buffer);
for (const auto& range : s->region) {
this->VisitExpr(range->min);
this->VisitExpr(range->extent);
@@ -97,10 +122,13 @@ void StmtVisitor::VisitStmt_(const SBlockNode* op) {
this->VisitExpr(iter_var->dom->min);
this->VisitExpr(iter_var->dom->extent);
});
+ VisitArray(op->alloc_buffers,
+ [this](const Buffer& buf) { this->VisitBufferDef(buf,
/*alloc_data=*/true); });
VisitArray(op->reads, fvisit_buffer_region);
VisitArray(op->writes, fvisit_buffer_region);
VisitArray(op->match_buffers,
- [fvisit_buffer_region](const MatchBufferRegion&
match_buffer_region) {
+ [this, &fvisit_buffer_region](const MatchBufferRegion&
match_buffer_region) {
+ this->VisitBufferDef(match_buffer_region->buffer,
/*alloc_data=*/true);
fvisit_buffer_region(match_buffer_region->source);
});
if (op->init.defined()) {
@@ -191,11 +219,12 @@ class StmtMutator::Internal {
static ffi::Array<BufferRegion> Mutate(StmtMutator* self, const
ffi::Array<BufferRegion>& arr) {
auto fmutate = [self](const BufferRegion& buffer_region) {
+ Buffer new_buf = self->VisitBufferUse(buffer_region->buffer);
ffi::Array<Range> region = Mutate(self, buffer_region->region);
- if (region.same_as(buffer_region->region)) {
+ if (new_buf.same_as(buffer_region->buffer) &&
region.same_as(buffer_region->region)) {
return buffer_region;
} else {
- return BufferRegion(buffer_region->buffer, region);
+ return BufferRegion(std::move(new_buf), std::move(region));
}
};
return MutateArray(self, arr, fmutate);
@@ -204,12 +233,16 @@ class StmtMutator::Internal {
static ffi::Array<MatchBufferRegion> Mutate(StmtMutator* self,
const
ffi::Array<MatchBufferRegion>& arr) {
auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
+ Buffer new_buf = self->VisitBufferDef(match_buffer_region->buffer,
/*alloc_data=*/true);
+ Buffer new_source_buf =
self->VisitBufferUse(match_buffer_region->source->buffer);
ffi::Array<Range> region = Mutate(self,
match_buffer_region->source->region);
- if (region.same_as(match_buffer_region->source->region)) {
+ if (new_buf.same_as(match_buffer_region->buffer) &&
+ new_source_buf.same_as(match_buffer_region->source->buffer) &&
+ region.same_as(match_buffer_region->source->region)) {
return match_buffer_region;
} else {
- return MatchBufferRegion(match_buffer_region->buffer,
-
BufferRegion(match_buffer_region->source->buffer, region));
+ return MatchBufferRegion(std::move(new_buf),
+ BufferRegion(std::move(new_source_buf),
std::move(region)));
}
};
return MutateArray(self, arr, fmutate);
@@ -276,25 +309,74 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) {
}
}
+Buffer StmtMutator::VisitBufferDef(const Buffer& buffer, bool alloc_data) {
+ if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+ return (*it).second;
+ }
+
+ // Visit expression fields (shape, strides, elem_offset) but NOT data.
+ // data is a Var definition owned by this buffer, not an expression use.
+ // Subclasses that need to remap data (e.g., IRSubstitute) can override.
+ auto shape = buffer->shape.Map([this](const PrimExpr& e) { return
this->VisitExpr(e); });
+ auto strides = buffer->strides.Map([this](const PrimExpr& e) { return
this->VisitExpr(e); });
+ PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset);
+
+ if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
+ elem_offset.same_as(buffer->elem_offset)) {
+ return buffer;
+ }
+ Buffer new_buf = buffer;
+ auto* n = new_buf.CopyOnWrite();
+ n->shape = std::move(shape);
+ n->strides = std::move(strides);
+ n->elem_offset = std::move(elem_offset);
+ buffer_remap_.Set(buffer, new_buf);
+ return new_buf;
+}
+
+Buffer StmtMutator::VisitBufferUse(const Buffer& buffer) {
+ if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+ return (*it).second;
+ }
+ return buffer;
+}
+
+PrimExpr StmtExprMutator::VisitExpr_(const BufferLoadNode* op) {
+ Buffer new_buf = this->VisitBufferUse(op->buffer);
+ PrimExpr expr = ExprMutator::VisitExpr_(op);
+ op = expr.as<BufferLoadNode>();
+ TVM_FFI_ICHECK(op != nullptr);
+ if (!new_buf.same_as(op->buffer)) {
+ auto n = ffi::make_object<BufferLoadNode>(*op);
+ n->buffer = std::move(new_buf);
+ return PrimExpr(n);
+ }
+ return expr;
+}
+
Stmt StmtMutator::VisitStmt_(const AllocBufferNode* op) {
+ Buffer new_buf = this->VisitBufferDef(op->buffer, /*alloc_data=*/true);
Stmt body = this->VisitStmt(op->body);
- if (body.same_as(op->body)) {
+ if (new_buf.same_as(op->buffer) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
+ n->buffer = std::move(new_buf);
n->body = std::move(body);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
+ Buffer new_buf = this->VisitBufferDef(op->buffer, /*alloc_data=*/false);
Stmt body = this->VisitStmt(op->body);
- if (body.same_as(op->body)) {
+ if (new_buf.same_as(op->buffer) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
+ n->buffer = std::move(new_buf);
n->body = std::move(body);
return Stmt(n);
}
@@ -320,13 +402,15 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
}
Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
+ Buffer new_buf = this->VisitBufferUse(op->buffer);
PrimExpr value = this->VisitExpr(op->value);
ffi::Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
- if (value.same_as(op->value) && indices.same_as(op->indices)) {
+ if (new_buf.same_as(op->buffer) && value.same_as(op->value) &&
indices.same_as(op->indices)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
+ n->buffer = std::move(new_buf);
n->value = std::move(value);
n->indices = std::move(indices);
return Stmt(n);
@@ -419,6 +503,9 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
Stmt StmtMutator::VisitStmt_(const SBlockNode* op) {
ffi::Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars);
+ ffi::Array<Buffer> alloc_buffers = Internal::MutateArray(
+ this, op->alloc_buffers,
+ [this](const Buffer& buf) { return this->VisitBufferDef(buf,
/*alloc_data=*/true); });
ffi::Array<BufferRegion> reads = Internal::Mutate(this, op->reads);
ffi::Array<BufferRegion> writes = Internal::Mutate(this, op->writes);
ffi::Array<MatchBufferRegion> match_buffers = Internal::Mutate(this,
op->match_buffers);
@@ -427,13 +514,14 @@ Stmt StmtMutator::VisitStmt_(const SBlockNode* op) {
init = VisitStmt(op->init.value());
}
Stmt body = VisitStmt(op->body);
- if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) &&
writes.same_as(op->writes) &&
- body.same_as(op->body) && init.same_as(op->init) &&
- match_buffers.same_as(op->match_buffers)) {
+ if (iter_vars.same_as(op->iter_vars) &&
alloc_buffers.same_as(op->alloc_buffers) &&
+ reads.same_as(op->reads) && writes.same_as(op->writes) &&
body.same_as(op->body) &&
+ init.same_as(op->init) && match_buffers.same_as(op->match_buffers)) {
return ffi::GetRef<SBlock>(op);
} else {
auto n = CopyOnWrite(op);
n->iter_vars = std::move(iter_vars);
+ n->alloc_buffers = std::move(alloc_buffers);
n->reads = std::move(reads);
n->writes = std::move(writes);
n->body = std::move(body);
@@ -477,6 +565,9 @@ class IRApplyVisit : public StmtExprVisitor {
f_(node);
}
+ void VisitBufferDef(const Buffer& buffer, bool alloc_data) override {}
+ void VisitBufferUse(const Buffer& buffer) override {}
+
private:
std::function<void(const ObjectRef&)> f_;
std::unordered_set<const Object*> visited_;
@@ -568,68 +659,23 @@ class IRSubstitute : public StmtExprMutator {
return var;
}
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- Stmt VisitStmt_(const AllocBufferNode* op) final {
- auto node = Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- Stmt VisitStmt_(const DeclBufferNode* op) final {
- auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- template <typename Node>
- Node VisitBufferAccess(Node node) {
- Buffer new_buf = GetRemappedBuffer(node->buffer);
-
- if (!new_buf.same_as(node->buffer)) {
- auto writer = node.CopyOnWrite();
- writer->buffer = new_buf;
- }
-
- return node;
- }
-
- Buffer GetRemappedBuffer(Buffer buf) {
- auto key = buf.get();
- auto it = buf_remap_.find(key);
- if (it != buf_remap_.end()) {
- return it->second;
- }
-
- PrimExpr new_buffer_var_expr = VisitExpr(buf->data);
- TVM_FFI_ICHECK(new_buffer_var_expr->IsInstance<VarNode>())
- << "Buffer " << buf << " uses backing allocation " << buf->data
- << ", which was substituted into the expression " <<
new_buffer_var_expr << ". "
- << "However, this expression is of type " <<
new_buffer_var_expr->GetTypeKey()
+ // Override VisitBufferDef to also remap buffer->data (the backing
allocation var).
+ // The base class only visits shape/strides/elem_offset.
+ Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) final {
+ Buffer new_buf = StmtExprMutator::VisitBufferDef(buffer, alloc_data);
+ // Additionally handle data var substitution (base does not visit data).
+ PrimExpr new_data_expr = VisitExpr(new_buf->data);
+ TVM_FFI_ICHECK(new_data_expr->IsInstance<VarNode>())
+ << "Buffer " << new_buf << " uses backing allocation " << new_buf->data
+ << ", which was substituted into the expression " << new_data_expr
<< " and the backing allocation must be a tir::Var";
-
- Var buffer_var = Downcast<Var>(new_buffer_var_expr);
- auto elem_offset = VisitExpr(buf->elem_offset);
- auto shape = buf->shape.Map([this](const auto& expr) { return
VisitExpr(expr); });
- auto strides = buf->strides.Map([this](const auto& expr) { return
VisitExpr(expr); });
-
- if (!buffer_var.same_as(buf->data) ||
!elem_offset.same_as(buf->elem_offset) ||
- !shape.same_as(buf->shape) || !strides.same_as(buf->strides)) {
- auto writer = buf.CopyOnWrite();
- writer->data = buffer_var;
- writer->elem_offset = elem_offset;
- writer->shape = shape;
- writer->strides = strides;
+ Var data = Downcast<Var>(new_data_expr);
+ if (!data.same_as(new_buf->data)) {
+ auto* n = new_buf.CopyOnWrite();
+ n->data = std::move(data);
+ buffer_remap_.Set(buffer, new_buf);
}
-
- buf_remap_[key] = buf;
- return buf;
+ return new_buf;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
@@ -647,15 +693,6 @@ class IRSubstitute : public StmtExprMutator {
private:
// Caller provided function that defines the variables to be remapped.
std::function<ffi::Optional<PrimExpr>(const Var&)> vmap_;
-
- /* \brief Generated map to track buffers being remapped.
- *
- * If a `Var BufferNode::data` is remapped, then all buffers
- * containing that data pointer should also be remapped. This map
- * is used to track buffer modifications, and ensure all instances
- * of a buffer are replaced by the same modified buffer object.
- */
- std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};
Stmt Substitute(Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var&)>
vmap) {
@@ -726,45 +763,6 @@ class IRSubstituteWithDataTypeLegalization : public
DataTypeLegalizer {
return var;
}
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
- return VisitBufferAccess(std::move(node));
- }
-
- template <typename Node>
- Node VisitBufferAccess(Node node) {
- Buffer new_buf = GetRemappedBuffer(node->buffer);
-
- if (!new_buf.same_as(node->buffer)) {
- auto writer = node.CopyOnWrite();
- writer->buffer = new_buf;
- }
-
- return node;
- }
-
- Buffer GetRemappedBuffer(Buffer buf) {
- auto key = buf.get();
- auto it = buf_remap_.find(key);
- if (it != buf_remap_.end()) {
- return it->second;
- }
-
- auto new_buffer_var = vmap_(buf->data);
- if (new_buffer_var.defined() &&
!new_buffer_var.value().same_as(buf->data)) {
- auto writer = buf.CopyOnWrite();
- writer->data = Downcast<Var>(new_buffer_var);
- }
-
- buf_remap_[key] = buf;
- return buf;
- }
-
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
@@ -780,15 +778,6 @@ class IRSubstituteWithDataTypeLegalization : public
DataTypeLegalizer {
private:
// Caller provided function that defines the variables to be remapped.
std::function<ffi::Optional<PrimExpr>(const Var&)> vmap_;
-
- /* \brief Generated map to track buffers being remapped.
- *
- * If a `Var BufferNode::data` is remapped, then all buffers
- * containing that data pointer should also be remapped. This map
- * is used to track buffer modifications, and ensure all instances
- * of a buffer are replaced by the same modified buffer object.
- */
- std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};
Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index 1cf888a542..5436e73d57 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -129,23 +129,26 @@ void TIRVisitorWithPath::ExitDef(const IterVar& iter_var,
AccessPath path) {
void TIRVisitorWithPath::EnterDef(const Buffer& buffer, AccessPath path) {
// Defining a buffer counts as using all parameters in the buffer
// (e.g. shape/strides).
- Visit(buffer->data, path->Attr("data"));
- Visit(buffer->shape, path->Attr("shape"));
- Visit(buffer->strides, path->Attr("strides"));
- Visit(buffer->elem_offset, path->Attr("elem_offset"));
+ VisitBufferDef(buffer, path);
}
void TIRVisitorWithPath::ExitDef(const Buffer& buffer, AccessPath path) {}
-void TIRVisitorWithPath::Visit(const Buffer& buffer, AccessPath path) {
- // Using a buffer *also* counts as using all parameters in the buffer.
+void TIRVisitorWithPath::VisitBufferDef(const Buffer& buffer, AccessPath path)
{
Visit(buffer->data, path->Attr("data"));
Visit(buffer->shape, path->Attr("shape"));
Visit(buffer->strides, path->Attr("strides"));
Visit(buffer->elem_offset, path->Attr("elem_offset"));
}
+// Default: buffer use sites do not re-visit buffer fields. Buffer fields
+// (shape, strides, elem_offset) are visited at the definition site via
+// VisitBufferDef/EnterDef. Re-visiting at use sites would require those
+// variables to be in scope at every use, which may not hold when buffers
+// are allocated in a different scope than where they are used.
+void TIRVisitorWithPath::VisitBufferUse(const Buffer& buffer, AccessPath path)
{}
+
void TIRVisitorWithPath::Visit(const BufferRegion& region, AccessPath path) {
- Visit(region->buffer, path->Attr("buffer"));
+ VisitBufferUse(region->buffer, path->Attr("buffer"));
Visit(region->region, path->Attr("region"));
}
@@ -225,7 +228,7 @@ void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode*
op, AccessPath path) {
void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath
path) {
Visit(op->value, path->Attr("value"));
- Visit(op->buffer, path->Attr("buffer"));
+ VisitBufferUse(op->buffer, path->Attr("buffer"));
Visit(op->indices, path->Attr("indices"));
}
@@ -308,7 +311,7 @@ void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op,
AccessPath path) {
}
void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, AccessPath path)
{
- Visit(op->buffer, path->Attr("buffer"));
+ VisitBufferUse(op->buffer, path->Attr("buffer"));
Visit(op->indices, path->Attr("indices"));
}
diff --git a/src/tir/ir/tir_visitor_with_path.h
b/src/tir/ir/tir_visitor_with_path.h
index 51e435b47d..f5189ae61c 100644
--- a/src/tir/ir/tir_visitor_with_path.h
+++ b/src/tir/ir/tir_visitor_with_path.h
@@ -53,12 +53,18 @@ class TIRVisitorWithPath
// Delegate to ExprFunctor::VisitStmt for Stmt, and any subclasses
inline void Visit(const Stmt& obj, ffi::reflection::AccessPath path) {
VisitStmt(obj, path); }
+ // Visit a buffer at a use site (BufferLoad, BufferStore, reads/writes).
+ // By default, does not re-visit buffer fields (shape, strides, elem_offset),
+ // as those are visited at the definition site via EnterDef.
+ virtual void VisitBufferUse(const Buffer& obj, ffi::reflection::AccessPath
path);
+ // Visit a buffer at a definition site. By default visits buffer fields.
+ virtual void VisitBufferDef(const Buffer& obj, ffi::reflection::AccessPath
path);
+
// Visitors for TIR constructs that are neither PrimExpr nor Stmt
virtual void Visit(const IRModule& obj, ffi::reflection::AccessPath path);
virtual void Visit(const PrimFunc& obj, ffi::reflection::AccessPath path);
virtual void Visit(const GlobalVar& obj, ffi::reflection::AccessPath path) {}
virtual void Visit(const Range& obj, ffi::reflection::AccessPath path);
- virtual void Visit(const Buffer& obj, ffi::reflection::AccessPath path);
virtual void Visit(const BufferRegion& obj, ffi::reflection::AccessPath
path);
virtual void Visit(const MatchBufferRegion& obj, ffi::reflection::AccessPath
path);
virtual void Visit(const IterVar& obj, ffi::reflection::AccessPath path);
diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc
index b9d8cdaab7..9398c2561e 100644
--- a/src/tir/transform/ir_utils.cc
+++ b/src/tir/transform/ir_utils.cc
@@ -205,6 +205,16 @@ class IRConvertSSA final : public StmtExprMutator {
return func;
}
+ // Do not use the base VisitBufferDef for buffer remapping.
+ //
+ // IRConvertSSA has its own scoped buffer remapping via GetRemappedBuffer and
+ // buf_remap_, which handles SSA conversion of buffer data vars, shape,
strides,
+ // and elem_offset with proper scope tracking. The base
StmtMutator::VisitBufferDef
+ // would create a conflicting second remap (into base buffer_remap_) when
called
+ // from the default DeclBuffer/AllocBuffer handlers, producing buffers with
+ // undefined SSA-renamed variables.
+ Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override {
return buffer; }
+
PrimExpr VisitExpr_(const VarNode* op) final { return
GetRemappedVar(ffi::GetRef<Var>(op)); }
PrimExpr VisitExpr_(const LetNode* op) final {
const Var& v = op->var;
diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc
index f06e52f328..af0fc4cf47 100644
--- a/src/tir/transform/simplify.cc
+++ b/src/tir/transform/simplify.cc
@@ -182,6 +182,19 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
using Parent::VisitStmt;
using Parent::VisitStmt_;
+ // Do not simplify buffer definition fields (shape, strides, elem_offset).
+ //
+ // The simplifier's VisitExpr override calls analyzer_->Simplify() directly,
+ // bypassing the normal ExprMutator dispatch. This means BufferLoad
expressions
+ // inside values (e.g., BufferStore value) skip VisitExpr_(BufferLoadNode*)
and
+ // thus skip VisitBufferUse. If VisitBufferDef remaps buffers at DeclBuffer
sites,
+ // the BufferLoad use sites won't pick up the remap, causing
DeclBuffer/BufferLoad
+ // buffer identity divergence and well-formedness violations.
+ //
+ // Instead, we keep buffer definitions unchanged and rely on
used_in_buffer_def_
+ // to prevent inlining LetStmt vars that appear in buffer definitions.
+ Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override {
return buffer; }
+
PrimExpr VisitExpr(const PrimExpr& expr) final {
if (config_->propagate_knowns_to_simplify_expressions) {
return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
analyzer_);
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index b36420dfe2..3375be981c 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -157,7 +157,9 @@ TEST(IRF, StmtVisitor) {
return AllocBuffer(buf, body);
};
v(fmaketest());
- TVM_FFI_ICHECK_EQ(v.count, 1);
+ // AllocBuffer now visits buffer shape via VisitBufferDef.
+ // shape = {z, z} where z = x + 1, so x is visited twice from shape + once
from body = 3
+ TVM_FFI_ICHECK_EQ(v.count, 3);
{
// tests for block and block_realize
@@ -176,7 +178,10 @@ TEST(IRF, StmtVisitor) {
v.count = 0;
v(block_realize);
- TVM_FFI_ICHECK_EQ(v.count, 5);
+ // Old count was 5 (x visited in reads/writes/match_buffers range + init
body + body).
+ // VisitBufferDef now also visits AllocBuffer's buffer shape {x+1, x+1} in
both init and body.
+ // Each adds 2 VarNode visits (x in each shape element), so 5 + 4 = 9.
+ TVM_FFI_ICHECK_EQ(v.count, 9);
}
}
@@ -221,8 +226,9 @@ TEST(IRF, StmtMutator) {
auto* arrptr = arr.get();
arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
TVM_FFI_ICHECK(arr.get() == arrptr);
- // buffer is not mutated (AllocBuffer mutator only visits body)
- TVM_FFI_ICHECK(arr[0].as<AllocBufferNode>()->buffer.get() == bufptr);
+ // buffer IS mutated now (AllocBuffer mutator visits buffer shape via
VisitBufferDef)
+ // shape was {1, x+1}, mutator transforms x+1 -> x, so buffer changes
+ TVM_FFI_ICHECK(arr[0].as<AllocBufferNode>()->buffer.get() != bufptr);
// body is mutated: x+1 -> x
TVM_FFI_ICHECK(!arr[0].as<AllocBufferNode>()->body.same_as(bref));
TVM_FFI_ICHECK(arr[0].as<AllocBufferNode>()->body.as<EvaluateNode>()->value.same_as(x));
@@ -270,7 +276,8 @@ TEST(IRF, StmtMutator) {
body = v(std::move(body));
// the seq get flattened
TVM_FFI_ICHECK(body.as<SeqStmtNode>()->size() == 3);
-
TVM_FFI_ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocBufferNode>()->buffer.get()
== bufptr);
+ // buffer is now mutated (shape x+1 -> x via VisitBufferDef)
+
TVM_FFI_ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocBufferNode>()->buffer.get()
!= bufptr);
TVM_FFI_ICHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
}
@@ -333,36 +340,43 @@ TEST(IRF, Substitute) {
using namespace tvm::tir;
DataType dtype = DataType::Float(32);
Var x("x", PointerType(PrimType(dtype), ""));
- auto fmaketest = [&]() {
- Buffer buffer{/*data=*/x,
+ Var n("n", DataType::Int(32));
+
+ auto fmakebuffer = [&]() {
+ return Buffer{/*data=*/x,
/*dtype=*/DataType::Float(32),
- /*shape=*/{},
+ /*shape=*/{n},
/*strides=*/{},
/*elem_offset=*/NullValue<PrimExpr>(),
/*name=*/"buf",
/*data_alignment=*/1,
/*offset_factor=*/1,
/*buffer_type=*/BufferType::kDefault};
- return BufferLoad(buffer, {});
};
{
- // test substitute buffer var
+ // test substitute buffer data var and shape var via DeclBuffer
Var y = x.copy_with_suffix("subst");
- BufferLoad buffer_load = fmaketest();
+ Var m("m", DataType::Int(32));
+ Buffer buffer = fmakebuffer();
+ Stmt store = BufferStore(buffer, FloatImm(dtype, 0),
{IntImm(DataType::Int(32), 0)});
+ Stmt decl = DeclBuffer(buffer, store);
auto f_subst = [&](const Var& var) -> ffi::Optional<PrimExpr> {
- if (var.same_as(x)) {
- return y;
- }
+ if (var.same_as(x)) return y;
+ if (var.same_as(n)) return m;
return std::nullopt;
};
- BufferLoad new_buffer_load = Downcast<BufferLoad>(Substitute(buffer_load,
f_subst));
- TVM_FFI_ICHECK(new_buffer_load->buffer->data.same_as(y));
+ Stmt new_decl = Substitute(decl, f_subst);
+ auto* decl_node = new_decl.as<DeclBufferNode>();
+ TVM_FFI_ICHECK(decl_node != nullptr);
+ TVM_FFI_ICHECK(decl_node->buffer->data.same_as(y));
+ TVM_FFI_ICHECK(decl_node->buffer->shape[0].same_as(m));
}
{
- // test identity substitution
- PrimExpr expr = fmaketest();
+ // test identity substitution on expression
+ Buffer buffer = fmakebuffer();
+ PrimExpr expr = BufferLoad(buffer, {IntImm(DataType::Int(32), 0)});
auto f_subst = [&](const Var& var) -> ffi::Optional<PrimExpr> { return
var; };
PrimExpr new_expr = Substitute(expr, f_subst);
// the expression is not changed
diff --git a/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py
b/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py
new file mode 100644
index 0000000000..3fd17830fd
--- /dev/null
+++ b/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for tir.analysis.undefined_vars (VarUseDefAnalyzer)."""
+
+import tvm
+import tvm.testing
+from tvm import tir
+
+
+def test_decl_buffer_data_is_use():
+ """DeclBuffer's data var should be reported as undefined (USE), not
defined.
+
+ When UndefinedVars encounters a DeclBuffer, the data pointer references
+ an existing variable from the enclosing scope. It must appear in the
+ undefined list so that callers (e.g., CreateComputeScope) capture it.
+ """
+ n = tir.SizeVar("n", "int32")
+ from tvm.ir import PointerType, PrimType
+
+ data_ptr = tir.Var("buf_data", PointerType(PrimType("float32")))
+ buf = tir.decl_buffer((n,), "float32", "buf", data=data_ptr)
+
+ body = tir.Evaluate(tir.BufferLoad(buf, [0]))
+ decl = tir.DeclBuffer(buf, body)
+
+ undef = tvm.tir.analysis.undefined_vars(decl, [])
+ undef_names = {v.name for v in undef}
+ # data_ptr must be undefined (it comes from outside the DeclBuffer)
+ assert "buf_data" in undef_names, f"Expected buf_data in undefined vars,
got {undef_names}"
+
+
+def test_decl_buffer_elem_offset_is_use():
+ """DeclBuffer's elem_offset var should be reported as undefined (USE).
+
+ After FlattenBuffer, DeclBuffer nodes carry elem_offset vars from
+ match_buffer entries. These must appear in the undefined list.
+ """
+ from tvm.ir import PointerType, PrimType
+
+ n = tir.SizeVar("n", "int32")
+ data_ptr = tir.Var("buf_data", PointerType(PrimType("float32")))
+ elem_off = tir.Var("buf_elem_offset", "int32")
+ buf = tir.decl_buffer((n,), "float32", "buf", data=data_ptr,
elem_offset=elem_off)
+
+ body = tir.Evaluate(tir.BufferLoad(buf, [0]))
+ decl = tir.DeclBuffer(buf, body)
+
+ undef = tvm.tir.analysis.undefined_vars(decl, [])
+ undef_names = {v.name for v in undef}
+ assert "buf_data" in undef_names, f"Expected buf_data in undefined vars,
got {undef_names}"
+ assert "buf_elem_offset" in undef_names, (
+ f"Expected buf_elem_offset in undefined vars, got {undef_names}"
+ )
+
+
+def test_alloc_buffer_data_is_def():
+ """AllocBuffer's data var should NOT be reported as undefined (it's a DEF).
+
+ AllocBuffer allocates new storage — the data pointer is a new definition,
+ not a reference to an external variable.
+ """
+ n = tir.SizeVar("n", "int32")
+ buf = tir.decl_buffer((n,), "float32", "buf")
+
+ body = tir.Evaluate(tir.BufferLoad(buf, [0]))
+ alloc = tir.AllocBuffer(buf, body)
+
+ undef = tvm.tir.analysis.undefined_vars(alloc, [])
+ undef_names = {v.name for v in undef}
+ # data should NOT be undefined — AllocBuffer defines it
+ assert buf.data.name not in undef_names, (
+ f"AllocBuffer data should be defined, but found {buf.data.name} in
{undef_names}"
+ )
+ # shape var n should be undefined (comes from enclosing scope)
+ assert "n" in undef_names, f"Expected shape var 'n' in undefined vars, got
{undef_names}"
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py
b/tests/python/tir-transform/test_tir_transform_simplify.py
index 3f73ed8e16..46e094acfa 100644
--- a/tests/python/tir-transform/test_tir_transform_simplify.py
+++ b/tests/python/tir-transform/test_tir_transform_simplify.py
@@ -1819,7 +1819,7 @@ def test_simplify_trivial_let_buffer_var():
def test_simplify_trivial_let_elem_offset():
- """A LetStmt used in a buffer definition should be retained"""
+ """A LetStmt used in a buffer definition should be retained, buffer fields
unchanged"""
@T.prim_func(private=True)
def before(A_ptr: T.handle("float32"), A_offset: T.int32):
@@ -1827,14 +1827,18 @@ def test_simplify_trivial_let_elem_offset():
A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
A[0] = 42.0
- expected = before
+ @T.prim_func(private=True)
+ def expected(A_ptr: T.handle("float32"), A_offset: T.int32):
+ A_offset_redef = A_offset
+ A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
+ A[0] = 42.0
after = _apply_simplify(before)
tvm.ir.assert_structural_equal(after, expected)
def test_simplify_trivial_let_shape():
- """A LetStmt used in a buffer definition should be retained"""
+ """A LetStmt used in a buffer definition should be retained, buffer fields
unchanged"""
@T.prim_func(private=True)
def before(A_ptr: T.handle("float32"), A_size: T.int32):
@@ -1842,14 +1846,18 @@ def test_simplify_trivial_let_shape():
A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
A[0] = 42.0
- expected = before
+ @T.prim_func(private=True)
+ def expected(A_ptr: T.handle("float32"), A_size: T.int32):
+ A_size_redef = A_size
+ A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
+ A[0] = 42.0
after = _apply_simplify(before)
tvm.ir.assert_structural_equal(after, expected)
def test_simplify_trivial_let_stride():
- """A LetStmt used in a buffer definition should be retained"""
+ """A LetStmt used in a buffer definition should be retained, buffer fields
unchanged"""
@T.prim_func(private=True)
def before(A_ptr: T.handle("float32"), A_stride: T.int32):
@@ -1857,12 +1865,37 @@ def test_simplify_trivial_let_stride():
A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
A[0] = 42.0
- expected = before
+ @T.prim_func(private=True)
+ def expected(A_ptr: T.handle("float32"), A_stride: T.int32):
+ A_stride_redef = A_stride
+ A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
+ A[0] = 42.0
after = _apply_simplify(before)
tvm.ir.assert_structural_equal(after, expected)
+def test_simplify_buffer_identity_well_formed():
+ """Regression: Simplify must not diverge buffer identity between
DeclBuffer and BufferLoad.
+
+ The simplifier's VisitExpr calls analyzer_->Simplify() directly, bypassing
+ normal ExprMutator dispatch. If VisitBufferDef remaps a buffer at a
DeclBuffer
+ site (e.g. inlining n_val -> n in the shape), BufferLoad inside a
BufferStore
+ value would NOT pick up the remap because VisitBufferUse is never called.
+ This causes DeclBuffer/BufferLoad buffer identity divergence.
+ """
+
+ @T.prim_func(private=True)
+ def before(A_ptr: T.handle("float32"), B_ptr: T.handle("float32"), n:
T.int32):
+ n_val = n
+ A = T.decl_buffer([n_val], "float32", data=A_ptr)
+ B = T.decl_buffer([n_val], "float32", data=B_ptr)
+ B[0] = A[0]
+
+ after = _apply_simplify(before)
+ tvm.tir.analysis.verify_well_formed(after)
+
+
def test_buffer_shape_constraint():
@I.ir_module(check_well_formed=False)
class Before: