This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 969fad363b [TIR] Add VisitBufferDef/VisitBufferUse to base StmtVisitor/StmtMutator (#18873)
969fad363b is described below

commit 969fad363be13375a4f4ecbc6a5fb2030e8e3f41
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Mar 4 22:06:03 2026 -0500

    [TIR] Add VisitBufferDef/VisitBufferUse to base StmtVisitor/StmtMutator (#18873)
---
 include/tvm/tir/stmt_functor.h                     |  42 +++-
 src/s_tir/transform/renew_defs.cc                  | 105 +++------
 src/tir/analysis/var_use_def_analysis.cc           |  50 ++--
 src/tir/analysis/var_use_def_analysis.h            |  10 +-
 src/tir/analysis/verify_well_formed.cc             |   6 +-
 src/tir/ir/data_type_rewriter.cc                   |  70 ++----
 src/tir/ir/data_type_rewriter.h                    |   8 +-
 src/tir/ir/specialize.cc                           |  31 +--
 src/tir/ir/stmt_functor.cc                         | 251 ++++++++++-----------
 src/tir/ir/tir_visitor_with_path.cc                |  21 +-
 src/tir/ir/tir_visitor_with_path.h                 |   8 +-
 src/tir/transform/ir_utils.cc                      |  10 +
 src/tir/transform/simplify.cc                      |  13 ++
 tests/cpp/ir_functor_test.cc                       |  50 ++--
 .../test_tir_analysis_undefined_vars.py            |  93 ++++++++
 .../tir-transform/test_tir_transform_simplify.py   |  45 +++-
 16 files changed, 451 insertions(+), 362 deletions(-)

diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 884f207b97..e86c6bb125 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -144,6 +144,20 @@ class TVM_DLL StmtVisitor : protected 
StmtFunctor<void(const Stmt&)> {
    *       and redirect Visit to ExprMutator::VisitExpr(Expr)
    */
   virtual void VisitExpr(const PrimExpr& e) {}
+  /*!
+   * \brief Visit buffer at definition site (AllocBuffer, DeclBuffer, SBlock alloc_buffers).
+   *  Visits buffer shape, strides, elem_offset via VisitExpr.
+   * \param buffer The buffer being defined.
+   * \param alloc_data If true, the buffer's data pointer is a new allocation (AllocBuffer);
+   *              if false, data references an existing variable (DeclBuffer).
+   */
+  virtual void VisitBufferDef(const Buffer& buffer, bool alloc_data);
+  /*!
+   * \brief Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes).
+   *  By default, this is a no-op, as buffer fields (shape, strides, elem_offset)
+   *  are visited at their definition site.
+   */
+  virtual void VisitBufferUse(const Buffer& buffer);
   // statement visitor
   void VisitStmt_(const AttrStmtNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
@@ -179,6 +193,8 @@ class TVM_DLL StmtMutator : protected 
StmtFunctor<Stmt(const Stmt&)> {
   }
 
  protected:
+  /*! \brief Map from old buffer to new buffer, populated by VisitBufferDef. */
+  ffi::Map<Buffer, Buffer> buffer_remap_;
   // We perform copy on write optimizations on the StmtMutator
   // so that an unique copy of parent can be mutated inplace
   // when some of its children changed.
@@ -240,6 +256,22 @@ class TVM_DLL StmtMutator : protected 
StmtFunctor<Stmt(const Stmt&)> {
    *       and redirect Mutate to ExprMutator::Mutate(Expr)
    */
   virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
+  /*!
+   * \brief Visit buffer at definition site. Visits shape/strides/elem_offset via VisitExpr.
+   *  If any field changes, creates a new buffer and records it in buffer_remap_.
+   * \param buffer The buffer being defined.
+   * \param alloc_data If true, the buffer's data pointer is a new allocation (AllocBuffer);
+   *              if false, data references an existing variable (DeclBuffer).
+   * \return The (possibly new) buffer.
+   */
+  virtual Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data);
+  /*!
+   * \brief Visit buffer at use site (BufferStore, BufferLoad, SBlock reads/writes).
+   *  By default, returns the remapped buffer from buffer_remap_ if exists, otherwise
+   *  returns the original buffer. Buffer fields are visited at their definition site.
+   * \return The (possibly remapped) buffer.
+   */
+  virtual Buffer VisitBufferUse(const Buffer& buffer);
   // statement visitor
   Stmt VisitStmt_(const AttrStmtNode* op) override;
   Stmt VisitStmt_(const IfThenElseNode* op) override;
@@ -276,31 +308,35 @@ class TVM_DLL StmtMutator : protected 
StmtFunctor<Stmt(const Stmt&)> {
 /*!
  * \brief Visitor that recursively visit stmts and exprs on them.
  */
-class StmtExprVisitor : public StmtVisitor, public ExprVisitor {
+class StmtExprVisitor : public ExprVisitor, public StmtVisitor {
  public:
   using StmtVisitor::operator();
   using ExprVisitor::operator();
 
  protected:
   using ExprVisitor::VisitExpr;
+  using ExprVisitor::VisitExpr_;
   using StmtVisitor::VisitStmt;
 
   void VisitExpr(const PrimExpr& e) override { return 
ExprVisitor::VisitExpr(e); }
+  void VisitExpr_(const BufferLoadNode* op) override;
 };
 
 /*!
  * \brief Mutator that recursively mutates stmts and exprs on them.
  */
-class StmtExprMutator : public StmtMutator, public ExprMutator {
+class StmtExprMutator : public ExprMutator, public StmtMutator {
  public:
   using StmtMutator::operator();
   using ExprMutator::operator();
 
  protected:
   using ExprMutator::VisitExpr;
-  using StmtMutator::VisitExpr;
+  using ExprMutator::VisitExpr_;
+  using StmtMutator::VisitStmt;
 
   PrimExpr VisitExpr(const PrimExpr& e) override { return 
ExprMutator::VisitExpr(e); }
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
 };
 
 /*!
diff --git a/src/s_tir/transform/renew_defs.cc 
b/src/s_tir/transform/renew_defs.cc
index f48cfff701..224fccbadb 100644
--- a/src/s_tir/transform/renew_defs.cc
+++ b/src/s_tir/transform/renew_defs.cc
@@ -71,7 +71,7 @@ class RenewDefMutator : public StmtExprMutator {
       if (param->dtype.is_handle()) {
         const Buffer& buffer = func->buffer_map.at(param);
         Var new_param = Downcast<Var>(generator.VisitExpr(param));
-        Buffer new_buffer = generator.VisitBuffer(buffer, true);
+        Buffer new_buffer = generator.DefineBuffer(buffer);
         buffer_map.Set(new_param, new_buffer);
       }
     }
@@ -102,40 +102,24 @@ class RenewDefMutator : public StmtExprMutator {
   STMT_REGENERATE_VAR_DEF(LetStmtNode, var);
   STMT_REGENERATE_VAR_DEF(ForNode, loop_var);
 
-  Stmt VisitStmt_(const AllocBufferNode* op) final {
-    Buffer new_buffer = VisitBuffer(op->buffer, /*define=*/true);
-    Stmt body = this->VisitStmt(op->body);
-    if (new_buffer.same_as(op->buffer) && body.same_as(op->body)) {
-      return ffi::GetRef<Stmt>(op);
-    } else {
-      auto n = ffi::make_object<AllocBufferNode>(*op);
-      n->buffer = std::move(new_buffer);
-      n->body = std::move(body);
-      return Stmt(n);
-    }
+  // Override VisitBufferDef to create fresh buffer copies at definition sites
+  // (AllocBuffer, DeclBuffer, SBlock alloc_buffers, match_buffers)
+  Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) final {
+    return DefineBuffer(buffer);
   }
 
-  Stmt VisitStmt_(const DeclBufferNode* op) final {
-    Buffer new_buffer = VisitBuffer(op->buffer, /*define=*/true);
-    Stmt body = this->VisitStmt(op->body);
-    if (new_buffer.same_as(op->buffer) && body.same_as(op->body)) {
-      return ffi::GetRef<Stmt>(op);
-    } else {
-      auto n = ffi::make_object<DeclBufferNode>(*op);
-      n->buffer = std::move(new_buffer);
-      n->body = std::move(body);
-      return Stmt(n);
-    }
-  }
+  // Override VisitBufferUse to remap buffers at use sites
+  // (BufferStore, BufferLoad, SBlock reads/writes)
+  Buffer VisitBufferUse(const Buffer& buffer) final { return 
UseOrRemapBuffer(buffer); }
 
   Stmt VisitStmt_(const SBlockNode* op) final {
     // Step 0. Re-define Itervars
     ffi::Array<IterVar> iter_vars =
         op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, 
std::placeholders::_1));
 
-    // Step 1. Re-define buffers allocate under the block
-    ffi::Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
-        std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, 
/*define=*/true));
+    // Step 1. Re-define buffers allocated under the block
+    ffi::Array<Buffer> alloc_buffers =
+        op->alloc_buffers.Map([this](const Buffer& buf) { return 
this->DefineBuffer(buf); });
 
     // Step 2. Re-define match_buffers
     ffi::Array<MatchBufferRegion> match_buffers = op->match_buffers.Map(
@@ -167,34 +151,6 @@ class RenewDefMutator : public StmtExprMutator {
     return Stmt(n);
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<BufferStoreNode>();
-    TVM_FFI_ICHECK(op != nullptr);
-    Buffer buffer = VisitDeclOrRemapBuffer(op->buffer);
-    if (buffer.same_as(op->buffer)) {
-      return stmt;
-    } else {
-      auto n = ffi::make_object<BufferStoreNode>(*op);
-      n->buffer = std::move(buffer);
-      return BufferStore(n);
-    }
-  }
-
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<BufferLoadNode>();
-    TVM_FFI_ICHECK(op != nullptr);
-    Buffer buffer = VisitDeclOrRemapBuffer(op->buffer);
-    if (buffer.same_as(op->buffer)) {
-      return expr;
-    } else {
-      auto n = ffi::make_object<BufferLoadNode>(*op);
-      n->buffer = std::move(buffer);
-      return BufferLoad(n);
-    }
-  }
-
  private:
   Var ReDefineVar(const Var& var) {
     Var new_var = Var(ffi::make_object<VarNode>(*var.get()));
@@ -208,12 +164,11 @@ class RenewDefMutator : public StmtExprMutator {
     remap_.Set(source, target);
   }
 
-  Buffer VisitBuffer(const Buffer& buffer, bool define = false) {
+  Buffer DefineBuffer(const Buffer& buffer) {
     auto it = remap_.find(buffer);
     if (it != remap_.end()) {
       return Downcast<Buffer>((*it).second);
     }
-    TVM_FFI_ICHECK(define);
 
     auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr {
       auto it = remap_.find(expr);
@@ -247,24 +202,9 @@ class RenewDefMutator : public StmtExprMutator {
     return new_buffer;
   }
 
-  IterVar VisitIterVar(const IterVar& iter_var) {
-    auto it = remap_.find(iter_var);
-    if (it != remap_.end()) {
-      return Downcast<IterVar>((*it).second);
-    }
-    PrimExpr min = VisitExpr(iter_var->dom->min);
-    PrimExpr extent = VisitExpr(iter_var->dom->extent);
-    IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), 
iter_var->iter_type,
-                         iter_var->thread_tag);
-    this->AddDefRemap(iter_var, new_iter_var);
-    return new_iter_var;
-  }
-
-  Buffer VisitDeclOrRemapBuffer(const Buffer& buffer) {
+  Buffer UseOrRemapBuffer(const Buffer& buffer) {
     // If the buffer has been remapped, return the remapped buffer, otherwise,
-    // return the declared one.
-    // Due to a recent PR, we can allow undefined buffer appearing in BufferLoad/Store. We need
-    // to remap them but will not create new var
+    // remap it without creating new var definitions.
     auto it = remap_.find(buffer);
     if (it != remap_.end()) {
       return Downcast<Buffer>((*it).second);
@@ -286,8 +226,21 @@ class RenewDefMutator : public StmtExprMutator {
     return new_buffer;
   }
 
+  IterVar VisitIterVar(const IterVar& iter_var) {
+    auto it = remap_.find(iter_var);
+    if (it != remap_.end()) {
+      return Downcast<IterVar>((*it).second);
+    }
+    PrimExpr min = VisitExpr(iter_var->dom->min);
+    PrimExpr extent = VisitExpr(iter_var->dom->extent);
+    IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), 
iter_var->iter_type,
+                         iter_var->thread_tag);
+    this->AddDefRemap(iter_var, new_iter_var);
+    return new_iter_var;
+  }
+
   MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) {
-    Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true);
+    Buffer buffer = DefineBuffer(match_buffer->buffer);
     BufferRegion region = VisitBufferRegion(match_buffer->source);
     return MatchBufferRegion(std::move(buffer), std::move(region));
   }
@@ -303,7 +256,7 @@ class RenewDefMutator : public StmtExprMutator {
   }
 
   BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) {
-    Buffer buffer = VisitBuffer(buffer_region->buffer);
+    Buffer buffer = UseOrRemapBuffer(buffer_region->buffer);
     ffi::Array<Range> region = buffer_region->region.Map(
         std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1));
     if (buffer.same_as(buffer_region->buffer) && 
region.same_as(buffer_region->region)) {
diff --git a/src/tir/analysis/var_use_def_analysis.cc 
b/src/tir/analysis/var_use_def_analysis.cc
index 9c4d26b58c..b2236e28ce 100644
--- a/src/tir/analysis/var_use_def_analysis.cc
+++ b/src/tir/analysis/var_use_def_analysis.cc
@@ -64,21 +64,8 @@ void VarUseDefAnalyzer::VisitStmt_(const ForNode* op) {
   StmtExprVisitor::VisitStmt_(op);
 }
 
-void VarUseDefAnalyzer::VisitStmt_(const DeclBufferNode* op) {
-  this->HandleDef(op->buffer);
-  StmtExprVisitor::VisitStmt_(op);
-}
-
 void VarUseDefAnalyzer::VisitStmt_(const AllocBufferNode* op) {
-  // AllocBuffer both allocates the data variable and declares the buffer,
-  // so we must define buffer->data before the buffer itself.
-  this->HandleDef(op->buffer->data);
-  this->HandleDef(op->buffer);
-  StmtExprVisitor::VisitStmt_(op);
-}
-
-void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) {
-  HandleUse(op->buffer);
+  // VisitBufferDef (called by base) defines buffer->data and the buffer itself.
   StmtExprVisitor::VisitStmt_(op);
 }
 
@@ -113,9 +100,29 @@ void VarUseDefAnalyzer::VisitExpr_(const ReduceNode* op) {
   StmtExprVisitor::VisitExpr_(op);
 }
 
-void VarUseDefAnalyzer::VisitExpr_(const BufferLoadNode* op) {
-  HandleUse(op->buffer);
-  StmtExprVisitor::VisitExpr_(op);
+void VarUseDefAnalyzer::VisitBufferDef(const Buffer& buffer, bool alloc_data) {
+  if (alloc_data) {
+    // AllocBuffer / SBlock: data is a new allocation — define it.
+    if (!use_count_.count(buffer->data.get())) {
+      HandleDef(buffer->data);
+    }
+  } else {
+    // DeclBuffer: data references an existing variable — use it.
+    HandleUse(buffer->data);
+  }
+  HandleDef(buffer);
+  // Visit shape/strides/elem_offset as uses of vars from the enclosing scope.
+  for (const auto& e : buffer->shape) this->VisitExpr(e);
+  for (const auto& e : buffer->strides) this->VisitExpr(e);
+  this->VisitExpr(buffer->elem_offset);
+}
+
+void VarUseDefAnalyzer::VisitBufferUse(const Buffer& buffer) {
+  HandleUse(buffer);
+  // Buffer data pointer must be tracked as a use — the use site
+  // reads/writes through this pointer.  Without this, UndefinedVars
+  // misses data vars for buffers whose DeclBuffer is outside the scope.
+  HandleUse(buffer->data);
 }
 
 void VarUseDefAnalyzer::VisitBuffer(const Buffer& buffer) {
@@ -162,8 +169,8 @@ void VarUseDefAnalyzer::HandleDef(const Buffer& buf) {
       << "buffer " << ptr->name << " has been used before definition!";
   buffer_use_count_[ptr] = 0;
   buffer_def_count_[ptr] = 1;
-
-  VisitBuffer(buf);
+  // Buffer fields (data, shape, strides) are visited by the caller
+  // (VisitBufferDef) via the base class, not here.
 }
 
 void VarUseDefAnalyzer::HandleUse(const Buffer& buf) {
@@ -177,8 +184,9 @@ void VarUseDefAnalyzer::HandleUse(const Buffer& buf) {
     undefined_buffers_.push_back(ffi::GetRef<Buffer>(ptr));
     buffer_use_count_[ptr] = -1;
   }
-
-  VisitBuffer(buf);
+  // Buffer fields (shape, strides, data) are visited at the definition
+  // site via VisitBufferDef.  Do not re-visit them at use sites, as the
+  // buffer's shape variables may not be in scope at the point of use.
 }
 
 ffi::Array<Var> UndefinedVars(const Stmt& stmt, const ffi::Array<Var>& args) {
diff --git a/src/tir/analysis/var_use_def_analysis.h 
b/src/tir/analysis/var_use_def_analysis.h
index 1c089c048f..7196fb2e8f 100644
--- a/src/tir/analysis/var_use_def_analysis.h
+++ b/src/tir/analysis/var_use_def_analysis.h
@@ -61,19 +61,19 @@ class VarUseDefAnalyzer : public StmtExprVisitor {
 
   void VisitStmt_(const ForNode* op) final;
 
-  void VisitStmt_(const DeclBufferNode* op) final;
-
   void VisitStmt_(const AllocBufferNode* op) final;
 
-  void VisitStmt_(const BufferStoreNode* op) final;
-
   void VisitExpr_(const LetNode* op) final;
 
   void VisitExpr_(const VarNode* op) final;
 
   void VisitExpr_(const ReduceNode* op) final;
 
-  void VisitExpr_(const BufferLoadNode* op) final;
+  // Piggyback on base class VisitBufferDef/VisitBufferUse to handle buffer
+  // def/use tracking. Base class calls these from AllocBuffer, DeclBuffer,
+  // BufferStore, BufferLoad, and SBlock visitors.
+  void VisitBufferDef(const Buffer& buffer, bool alloc_data) final;
+  void VisitBufferUse(const Buffer& buffer) final;
 
   void HandleDef(const Var& v);
   void HandleUse(const Var& v);
diff --git a/src/tir/analysis/verify_well_formed.cc 
b/src/tir/analysis/verify_well_formed.cc
index 0ff363a547..5d8a2d5778 100644
--- a/src/tir/analysis/verify_well_formed.cc
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -348,7 +348,7 @@ class UndefinedBufferVerifier : public 
Verifier<UndefinedBufferVerifier> {
     previously_defined_.insert({buffer, path});
   }
 
-  void Visit(const Buffer& buffer, AccessPath path) override {
+  void VisitBufferUse(const Buffer& buffer, AccessPath path) override {
     bool is_declared = currently_defined_.count(buffer);
     bool was_declared = previously_defined_.count(buffer);
 
@@ -363,8 +363,8 @@ class UndefinedBufferVerifier : public 
Verifier<UndefinedBufferVerifier> {
       Verify(false) << "TIR is ill-formed: buffer " << buffer->name << " is 
used at " << path
                     << " without a prior DeclBuffer or other declaration.";
     }
-    // Still visit the buffer's internal vars so variable usage is tracked.
-    Verifier::Visit(buffer, path);
+    // Buffer fields are visited at definition site (EnterDef), not here.
+    Verifier::VisitBufferUse(buffer, path);
   }
 
   // Buffers defined in the currently-visited scope.
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index c4d83c2236..f7c508dac9 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -277,22 +277,16 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const 
AttrStmtNode* op) {
   return DataTypeLegalizer::VisitStmt_(op);
 }
 
-Stmt IndexDataTypeRewriter::VisitStmt_(const AllocBufferNode* op) {
-  Buffer new_buffer = VisitBuffer(op->buffer);
-  AllocBuffer alloc_buffer = 
Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
-  if (!new_buffer.same_as(op->buffer)) {
-    alloc_buffer.CopyOnWrite()->buffer = new_buffer;
-  }
-  return alloc_buffer;
+Buffer IndexDataTypeRewriter::VisitBufferDef(const Buffer& buffer, bool 
alloc_data) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Buffer new_buf = StmtMutator::VisitBufferDef(buffer, alloc_data);
+  is_enabled_ = is_enabled;
+  return new_buf;
 }
 
-Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
-  Buffer new_buffer = VisitBuffer(op->buffer);
-  DeclBuffer decl_buffer = 
Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
-  if (!new_buffer.same_as(op->buffer)) {
-    decl_buffer.CopyOnWrite()->buffer = new_buffer;
-  }
-  return decl_buffer;
+Buffer IndexDataTypeRewriter::VisitBufferUse(const Buffer& buffer) {
+  return StmtMutator::VisitBufferUse(buffer);
 }
 
 Stmt IndexDataTypeRewriter::VisitStmt_(const SBlockRealizeNode* op) {
@@ -322,11 +316,11 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const 
SBlockRealizeNode* op) {
 }
 
 Stmt IndexDataTypeRewriter::VisitStmt_(const SBlockNode* op) {
-  ffi::Array<Buffer> new_alloc_buffers =
-      op->alloc_buffers.Map([this](const Buffer& buffer) { return 
this->VisitBuffer(buffer); });
+  ffi::Array<Buffer> new_alloc_buffers = op->alloc_buffers.Map(
+      [this](const Buffer& buffer) { return this->VisitBufferDef(buffer, 
/*alloc_data=*/true); });
   ffi::Array<MatchBufferRegion> new_match_buffers =
       op->match_buffers.Map([this](const MatchBufferRegion& 
match_buffer_region) {
-        Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer);
+        Buffer new_buffer = this->VisitBufferDef(match_buffer_region->buffer, 
/*alloc_data=*/true);
         BufferRegion new_buffer_region = 
this->VisitBufferRegion(match_buffer_region->source);
         if (!new_buffer.same_as(match_buffer_region->buffer) ||
             !new_buffer_region.same_as(match_buffer_region->source)) {
@@ -378,7 +372,7 @@ ffi::Map<ffi::String, ffi::Any> 
IndexDataTypeRewriter::VisitBlockAnnotations(
     }
     if (obj.as<BufferNode>()) {
       Buffer buffer = Downcast<Buffer>(obj);
-      if (Buffer new_buffer = GetRemappedBuffer(buffer); 
!new_buffer.same_as(buffer)) {
+      if (Buffer new_buffer = VisitBufferUse(buffer); 
!new_buffer.same_as(buffer)) {
         return new_buffer;
       }
     } else if (obj.as<ffi::ArrayObj>()) {
@@ -397,13 +391,6 @@ ffi::Map<ffi::String, ffi::Any> 
IndexDataTypeRewriter::VisitBlockAnnotations(
   return new_annotations;
 }
 
-Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) {
-  if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
-    return (*it).second;
-  }
-  return buffer;
-}
-
 IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
   bool is_enabled = is_enabled_;
   is_enabled_ = true;
@@ -422,33 +409,8 @@ IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& 
iter_var) {
   return iter_var;
 }
 
-Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
-  bool is_enabled = is_enabled_;
-
-  is_enabled_ = true;
-  ffi::Array<PrimExpr> new_shape =
-      buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
-  ffi::Array<PrimExpr> new_strides =
-      buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); 
});
-  auto new_elem_offset = VisitExpr(buffer->elem_offset);
-  is_enabled_ = is_enabled;
-
-  if (!buffer->shape.same_as(new_shape) || 
!buffer->strides.same_as(new_strides) ||
-      !buffer->elem_offset.same_as(new_elem_offset)) {
-    Buffer new_buffer = buffer;
-    BufferNode* new_buffer_node = new_buffer.CopyOnWrite();
-    new_buffer_node->shape = std::move(new_shape);
-    new_buffer_node->strides = std::move(new_strides);
-    new_buffer_node->elem_offset = std::move(new_elem_offset);
-    buffer_remap_.Set(buffer, new_buffer);
-    return new_buffer;
-  } else {
-    return buffer;
-  }
-}
-
 BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& 
buffer_region) {
-  Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
+  Buffer remapped_buffer = VisitBufferUse(buffer_region->buffer);
 
   bool is_enabled = is_enabled_;
   is_enabled_ = true;
@@ -468,7 +430,7 @@ BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const 
BufferRegion& buffer
 Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
   BufferStore store = ffi::GetRef<BufferStore>(op);
 
-  Buffer new_buffer = GetRemappedBuffer(op->buffer);
+  Buffer new_buffer = VisitBufferUse(op->buffer);
   auto value = this->VisitExpr(op->value);
   if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
     value = cast(new_buffer->dtype, value);
@@ -489,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const 
BufferStoreNode* op) {
 PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) {
   BufferLoad load = ffi::GetRef<BufferLoad>(op);
 
-  Buffer new_buffer = GetRemappedBuffer(op->buffer);
+  Buffer new_buffer = VisitBufferUse(op->buffer);
   auto indices = VisitIndices(op->indices);
 
   if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) {
@@ -642,7 +604,7 @@ PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
   // start rewrite
   ffi::Map<Var, Buffer> new_buffer_map = func->buffer_map;
   for (const auto& [var, buffer] : func->buffer_map) {
-    new_buffer_map.Set(var, VisitBuffer(buffer));
+    new_buffer_map.Set(var, VisitBufferDef(buffer, /*alloc_data=*/true));
   }
   // remap params
   bool is_enabled = true;
diff --git a/src/tir/ir/data_type_rewriter.h b/src/tir/ir/data_type_rewriter.h
index 8196485273..e886777096 100644
--- a/src/tir/ir/data_type_rewriter.h
+++ b/src/tir/ir/data_type_rewriter.h
@@ -101,6 +101,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   using Parent::VisitExpr_;
   using Parent::VisitStmt_;
 
+  Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override;
+  Buffer VisitBufferUse(const Buffer& buffer) override;
   Stmt VisitStmt_(const SBlockRealizeNode* op) override;
   Stmt VisitStmt_(const SBlockNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
@@ -108,8 +110,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   PrimExpr VisitExpr_(const BufferLoadNode* op) override;
   ffi::Array<PrimExpr> VisitIndices(ffi::Array<PrimExpr> indices);
   Stmt VisitStmt_(const IfThenElseNode* op) override;
-  Stmt VisitStmt_(const DeclBufferNode* op) override;
-  Stmt VisitStmt_(const AllocBufferNode* op) override;
   Stmt VisitStmt_(const LetStmtNode* op) override;
   PrimExpr VisitExpr_(const EQNode* op) override;
   PrimExpr VisitExpr_(const NENode* op) override;
@@ -122,8 +122,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
 
   Stmt VisitStmt_(const ForNode* op) override;
 
-  Buffer VisitBuffer(const Buffer& buffer);
-  Buffer GetRemappedBuffer(const Buffer& buffer);
   ffi::Map<ffi::String, ffi::Any> VisitBlockAnnotations(
       const ffi::Map<ffi::String, ffi::Any>& annotations);
   BufferRegion VisitBufferRegion(const BufferRegion& region);
@@ -132,8 +130,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   bool is_enabled_{false};
   // indicator of condition
   bool is_condition_{false};
-
-  ffi::Map<Buffer, Buffer> buffer_remap_;
 };
 
 /*!
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 6c4f4666ef..1ad0749711 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -177,35 +177,8 @@ class PrimFuncSpecializer : public StmtExprMutator {
     return stmt;
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<BufferStoreNode>();
-    TVM_FFI_ICHECK(op != nullptr);
-
-    auto new_buf = GetNewBuffer(op->buffer);
-    if (new_buf.same_as(op->buffer)) {
-      return ffi::GetRef<BufferStore>(op);
-    } else {
-      auto n = CopyOnWrite(op);
-      n->buffer = new_buf;
-      return Stmt(n);
-    }
-  }
-
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<BufferLoadNode>();
-    TVM_FFI_ICHECK(op != nullptr);
-
-    auto new_buf = GetNewBuffer(op->buffer);
-    if (new_buf.same_as(op->buffer)) {
-      return ffi::GetRef<BufferLoad>(op);
-    } else {
-      auto n = ffi::make_object<BufferLoadNode>(*op);
-      n->buffer = new_buf;
-      return PrimExpr(n);
-    }
-  }
+  // Override VisitBufferUse to use our own buffer_map_ instead of base class field visiting.
+  Buffer VisitBufferUse(const Buffer& buffer) final { return 
GetNewBuffer(buffer); }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = var_map_.find(ffi::GetRef<Var>(op));
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ce26337e15..ff79c374db 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -57,11 +57,35 @@ void StmtVisitor::VisitStmt_(const WhileNode* op) {
   this->VisitStmt(op->body);
 }
 
-void StmtVisitor::VisitStmt_(const AllocBufferNode* op) { 
this->VisitStmt(op->body); }
+void StmtVisitor::VisitBufferDef(const Buffer& buffer, bool alloc_data) {
+  for (const auto& e : buffer->shape) this->VisitExpr(e);
+  for (const auto& e : buffer->strides) this->VisitExpr(e);
+  this->VisitExpr(buffer->elem_offset);
+}
+
+// Default VisitBufferUse is empty: buffer fields (shape, strides, elem_offset)
+// are visited at the definition site (VisitBufferDef) and should not be
+// re-visited at each use site, as the use site may be in a different scope
+// where the buffer's shape variables are not defined.
+void StmtVisitor::VisitBufferUse(const Buffer& buffer) {}
+
+void StmtExprVisitor::VisitExpr_(const BufferLoadNode* op) {
+  this->VisitBufferUse(op->buffer);
+  ExprVisitor::VisitExpr_(op);
+}
 
-void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { 
this->VisitStmt(op->body); }
+void StmtVisitor::VisitStmt_(const AllocBufferNode* op) {
+  this->VisitBufferDef(op->buffer, /*alloc_data=*/true);
+  this->VisitStmt(op->body);
+}
+
+void StmtVisitor::VisitStmt_(const DeclBufferNode* op) {
+  this->VisitBufferDef(op->buffer, /*alloc_data=*/false);
+  this->VisitStmt(op->body);
+}
 
 void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
+  this->VisitBufferUse(op->buffer);
   this->VisitExpr(op->value);
   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
@@ -88,6 +112,7 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) { 
this->VisitExpr(op->value
 
 void StmtVisitor::VisitStmt_(const SBlockNode* op) {
   auto fvisit_buffer_region = [this](const BufferRegion& s) {
+    this->VisitBufferUse(s->buffer);
     for (const auto& range : s->region) {
       this->VisitExpr(range->min);
       this->VisitExpr(range->extent);
@@ -97,10 +122,13 @@ void StmtVisitor::VisitStmt_(const SBlockNode* op) {
     this->VisitExpr(iter_var->dom->min);
     this->VisitExpr(iter_var->dom->extent);
   });
+  VisitArray(op->alloc_buffers,
+             [this](const Buffer& buf) { this->VisitBufferDef(buf, 
/*alloc_data=*/true); });
   VisitArray(op->reads, fvisit_buffer_region);
   VisitArray(op->writes, fvisit_buffer_region);
   VisitArray(op->match_buffers,
-             [fvisit_buffer_region](const MatchBufferRegion& 
match_buffer_region) {
+             [this, &fvisit_buffer_region](const MatchBufferRegion& 
match_buffer_region) {
+               this->VisitBufferDef(match_buffer_region->buffer, 
/*alloc_data=*/true);
                fvisit_buffer_region(match_buffer_region->source);
              });
   if (op->init.defined()) {
@@ -191,11 +219,12 @@ class StmtMutator::Internal {
 
   static ffi::Array<BufferRegion> Mutate(StmtMutator* self, const 
ffi::Array<BufferRegion>& arr) {
     auto fmutate = [self](const BufferRegion& buffer_region) {
+      Buffer new_buf = self->VisitBufferUse(buffer_region->buffer);
       ffi::Array<Range> region = Mutate(self, buffer_region->region);
-      if (region.same_as(buffer_region->region)) {
+      if (new_buf.same_as(buffer_region->buffer) && 
region.same_as(buffer_region->region)) {
         return buffer_region;
       } else {
-        return BufferRegion(buffer_region->buffer, region);
+        return BufferRegion(std::move(new_buf), std::move(region));
       }
     };
     return MutateArray(self, arr, fmutate);
@@ -204,12 +233,16 @@ class StmtMutator::Internal {
   static ffi::Array<MatchBufferRegion> Mutate(StmtMutator* self,
                                               const 
ffi::Array<MatchBufferRegion>& arr) {
     auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
+      Buffer new_buf = self->VisitBufferDef(match_buffer_region->buffer, 
/*alloc_data=*/true);
+      Buffer new_source_buf = 
self->VisitBufferUse(match_buffer_region->source->buffer);
       ffi::Array<Range> region = Mutate(self, 
match_buffer_region->source->region);
-      if (region.same_as(match_buffer_region->source->region)) {
+      if (new_buf.same_as(match_buffer_region->buffer) &&
+          new_source_buf.same_as(match_buffer_region->source->buffer) &&
+          region.same_as(match_buffer_region->source->region)) {
         return match_buffer_region;
       } else {
-        return MatchBufferRegion(match_buffer_region->buffer,
-                                 
BufferRegion(match_buffer_region->source->buffer, region));
+        return MatchBufferRegion(std::move(new_buf),
+                                 BufferRegion(std::move(new_source_buf), 
std::move(region)));
       }
     };
     return MutateArray(self, arr, fmutate);
@@ -276,25 +309,74 @@ Stmt StmtMutator::VisitStmt_(const WhileNode* op) {
   }
 }
 
+Buffer StmtMutator::VisitBufferDef(const Buffer& buffer, bool alloc_data) {
+  if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+    return (*it).second;
+  }
+
+  // Visit expression fields (shape, strides, elem_offset) but NOT data.
+  // data is a Var definition owned by this buffer, not an expression use.
+  // Subclasses that need to remap data (e.g., IRSubstitute) can override.
+  auto shape = buffer->shape.Map([this](const PrimExpr& e) { return 
this->VisitExpr(e); });
+  auto strides = buffer->strides.Map([this](const PrimExpr& e) { return 
this->VisitExpr(e); });
+  PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset);
+
+  if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
+      elem_offset.same_as(buffer->elem_offset)) {
+    return buffer;
+  }
+  Buffer new_buf = buffer;
+  auto* n = new_buf.CopyOnWrite();
+  n->shape = std::move(shape);
+  n->strides = std::move(strides);
+  n->elem_offset = std::move(elem_offset);
+  buffer_remap_.Set(buffer, new_buf);
+  return new_buf;
+}
+
+Buffer StmtMutator::VisitBufferUse(const Buffer& buffer) {
+  if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+    return (*it).second;
+  }
+  return buffer;
+}
+
+PrimExpr StmtExprMutator::VisitExpr_(const BufferLoadNode* op) {
+  Buffer new_buf = this->VisitBufferUse(op->buffer);
+  PrimExpr expr = ExprMutator::VisitExpr_(op);
+  op = expr.as<BufferLoadNode>();
+  TVM_FFI_ICHECK(op != nullptr);
+  if (!new_buf.same_as(op->buffer)) {
+    auto n = ffi::make_object<BufferLoadNode>(*op);
+    n->buffer = std::move(new_buf);
+    return PrimExpr(n);
+  }
+  return expr;
+}
+
 Stmt StmtMutator::VisitStmt_(const AllocBufferNode* op) {
+  Buffer new_buf = this->VisitBufferDef(op->buffer, /*alloc_data=*/true);
   Stmt body = this->VisitStmt(op->body);
 
-  if (body.same_as(op->body)) {
+  if (new_buf.same_as(op->buffer) && body.same_as(op->body)) {
     return ffi::GetRef<Stmt>(op);
   } else {
     auto n = CopyOnWrite(op);
+    n->buffer = std::move(new_buf);
     n->body = std::move(body);
     return Stmt(n);
   }
 }
 
 Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
+  Buffer new_buf = this->VisitBufferDef(op->buffer, /*alloc_data=*/false);
   Stmt body = this->VisitStmt(op->body);
 
-  if (body.same_as(op->body)) {
+  if (new_buf.same_as(op->buffer) && body.same_as(op->body)) {
     return ffi::GetRef<Stmt>(op);
   } else {
     auto n = CopyOnWrite(op);
+    n->buffer = std::move(new_buf);
     n->body = std::move(body);
     return Stmt(n);
   }
@@ -320,13 +402,15 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
 }
 
 Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
+  Buffer new_buf = this->VisitBufferUse(op->buffer);
   PrimExpr value = this->VisitExpr(op->value);
   ffi::Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
 
-  if (value.same_as(op->value) && indices.same_as(op->indices)) {
+  if (new_buf.same_as(op->buffer) && value.same_as(op->value) && 
indices.same_as(op->indices)) {
     return ffi::GetRef<Stmt>(op);
   } else {
     auto n = CopyOnWrite(op);
+    n->buffer = std::move(new_buf);
     n->value = std::move(value);
     n->indices = std::move(indices);
     return Stmt(n);
@@ -419,6 +503,9 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
 
 Stmt StmtMutator::VisitStmt_(const SBlockNode* op) {
   ffi::Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars);
+  ffi::Array<Buffer> alloc_buffers = Internal::MutateArray(
+      this, op->alloc_buffers,
+      [this](const Buffer& buf) { return this->VisitBufferDef(buf, 
/*alloc_data=*/true); });
   ffi::Array<BufferRegion> reads = Internal::Mutate(this, op->reads);
   ffi::Array<BufferRegion> writes = Internal::Mutate(this, op->writes);
   ffi::Array<MatchBufferRegion> match_buffers = Internal::Mutate(this, 
op->match_buffers);
@@ -427,13 +514,14 @@ Stmt StmtMutator::VisitStmt_(const SBlockNode* op) {
     init = VisitStmt(op->init.value());
   }
   Stmt body = VisitStmt(op->body);
-  if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && 
writes.same_as(op->writes) &&
-      body.same_as(op->body) && init.same_as(op->init) &&
-      match_buffers.same_as(op->match_buffers)) {
+  if (iter_vars.same_as(op->iter_vars) && 
alloc_buffers.same_as(op->alloc_buffers) &&
+      reads.same_as(op->reads) && writes.same_as(op->writes) && 
body.same_as(op->body) &&
+      init.same_as(op->init) && match_buffers.same_as(op->match_buffers)) {
     return ffi::GetRef<SBlock>(op);
   } else {
     auto n = CopyOnWrite(op);
     n->iter_vars = std::move(iter_vars);
+    n->alloc_buffers = std::move(alloc_buffers);
     n->reads = std::move(reads);
     n->writes = std::move(writes);
     n->body = std::move(body);
@@ -477,6 +565,9 @@ class IRApplyVisit : public StmtExprVisitor {
     f_(node);
   }
 
+  void VisitBufferDef(const Buffer& buffer, bool alloc_data) override {}
+  void VisitBufferUse(const Buffer& buffer) override {}
+
  private:
   std::function<void(const ObjectRef&)> f_;
   std::unordered_set<const Object*> visited_;
@@ -568,68 +659,23 @@ class IRSubstitute : public StmtExprMutator {
     return var;
   }
 
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  Stmt VisitStmt_(const AllocBufferNode* op) final {
-    auto node = Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  Stmt VisitStmt_(const DeclBufferNode* op) final {
-    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  template <typename Node>
-  Node VisitBufferAccess(Node node) {
-    Buffer new_buf = GetRemappedBuffer(node->buffer);
-
-    if (!new_buf.same_as(node->buffer)) {
-      auto writer = node.CopyOnWrite();
-      writer->buffer = new_buf;
-    }
-
-    return node;
-  }
-
-  Buffer GetRemappedBuffer(Buffer buf) {
-    auto key = buf.get();
-    auto it = buf_remap_.find(key);
-    if (it != buf_remap_.end()) {
-      return it->second;
-    }
-
-    PrimExpr new_buffer_var_expr = VisitExpr(buf->data);
-    TVM_FFI_ICHECK(new_buffer_var_expr->IsInstance<VarNode>())
-        << "Buffer " << buf << " uses backing allocation " << buf->data
-        << ", which was substituted into the expression " << 
new_buffer_var_expr << ".  "
-        << "However, this expression is of type " << 
new_buffer_var_expr->GetTypeKey()
+  // Override VisitBufferDef to also remap buffer->data (the backing 
allocation var).
+  // The base class only visits shape/strides/elem_offset.
+  Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) final {
+    Buffer new_buf = StmtExprMutator::VisitBufferDef(buffer, alloc_data);
+    // Additionally handle data var substitution (base does not visit data).
+    PrimExpr new_data_expr = VisitExpr(new_buf->data);
+    TVM_FFI_ICHECK(new_data_expr->IsInstance<VarNode>())
+        << "Buffer " << new_buf << " uses backing allocation " << new_buf->data
+        << ", which was substituted into the expression " << new_data_expr
         << " and the backing allocation must be a tir::Var";
-
-    Var buffer_var = Downcast<Var>(new_buffer_var_expr);
-    auto elem_offset = VisitExpr(buf->elem_offset);
-    auto shape = buf->shape.Map([this](const auto& expr) { return 
VisitExpr(expr); });
-    auto strides = buf->strides.Map([this](const auto& expr) { return 
VisitExpr(expr); });
-
-    if (!buffer_var.same_as(buf->data) || 
!elem_offset.same_as(buf->elem_offset) ||
-        !shape.same_as(buf->shape) || !strides.same_as(buf->strides)) {
-      auto writer = buf.CopyOnWrite();
-      writer->data = buffer_var;
-      writer->elem_offset = elem_offset;
-      writer->shape = shape;
-      writer->strides = strides;
+    Var data = Downcast<Var>(new_data_expr);
+    if (!data.same_as(new_buf->data)) {
+      auto* n = new_buf.CopyOnWrite();
+      n->data = std::move(data);
+      buffer_remap_.Set(buffer, new_buf);
     }
-
-    buf_remap_[key] = buf;
-    return buf;
+    return new_buf;
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
@@ -647,15 +693,6 @@ class IRSubstitute : public StmtExprMutator {
  private:
   // Caller provided function that defines the variables to be remapped.
   std::function<ffi::Optional<PrimExpr>(const Var&)> vmap_;
-
-  /* \brief Generated map to track buffers being remapped.
-   *
-   * If a `Var BufferNode::data` is remapped, then all buffers
-   * containing that data pointer should also be remapped.  This map
-   * is used to track buffer modifications, and ensure all instances
-   * of a buffer are replaced by the same modified buffer object.
-   */
-  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
 };
 
 Stmt Substitute(Stmt stmt, std::function<ffi::Optional<PrimExpr>(const Var&)> 
vmap) {
@@ -726,45 +763,6 @@ class IRSubstituteWithDataTypeLegalization : public 
DataTypeLegalizer {
     return var;
   }
 
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    return VisitBufferAccess(std::move(node));
-  }
-
-  template <typename Node>
-  Node VisitBufferAccess(Node node) {
-    Buffer new_buf = GetRemappedBuffer(node->buffer);
-
-    if (!new_buf.same_as(node->buffer)) {
-      auto writer = node.CopyOnWrite();
-      writer->buffer = new_buf;
-    }
-
-    return node;
-  }
-
-  Buffer GetRemappedBuffer(Buffer buf) {
-    auto key = buf.get();
-    auto it = buf_remap_.find(key);
-    if (it != buf_remap_.end()) {
-      return it->second;
-    }
-
-    auto new_buffer_var = vmap_(buf->data);
-    if (new_buffer_var.defined() && 
!new_buffer_var.value().same_as(buf->data)) {
-      auto writer = buf.CopyOnWrite();
-      writer->data = Downcast<Var>(new_buffer_var);
-    }
-
-    buf_remap_[key] = buf;
-    return buf;
-  }
-
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<AttrStmtNode>();
@@ -780,15 +778,6 @@ class IRSubstituteWithDataTypeLegalization : public 
DataTypeLegalizer {
  private:
   // Caller provided function that defines the variables to be remapped.
   std::function<ffi::Optional<PrimExpr>(const Var&)> vmap_;
-
-  /* \brief Generated map to track buffers being remapped.
-   *
-   * If a `Var BufferNode::data` is remapped, then all buffers
-   * containing that data pointer should also be remapped.  This map
-   * is used to track buffer modifications, and ensure all instances
-   * of a buffer are replaced by the same modified buffer object.
-   */
-  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
 };
 
 Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
diff --git a/src/tir/ir/tir_visitor_with_path.cc 
b/src/tir/ir/tir_visitor_with_path.cc
index 1cf888a542..5436e73d57 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -129,23 +129,26 @@ void TIRVisitorWithPath::ExitDef(const IterVar& iter_var, 
AccessPath path) {
 void TIRVisitorWithPath::EnterDef(const Buffer& buffer, AccessPath path) {
   // Defining a buffer counts as using all parameters in the buffer
   // (e.g. shape/strides).
-  Visit(buffer->data, path->Attr("data"));
-  Visit(buffer->shape, path->Attr("shape"));
-  Visit(buffer->strides, path->Attr("strides"));
-  Visit(buffer->elem_offset, path->Attr("elem_offset"));
+  VisitBufferDef(buffer, path);
 }
 void TIRVisitorWithPath::ExitDef(const Buffer& buffer, AccessPath path) {}
 
-void TIRVisitorWithPath::Visit(const Buffer& buffer, AccessPath path) {
-  // Using a buffer *also* counts as using all parameters in the buffer.
+void TIRVisitorWithPath::VisitBufferDef(const Buffer& buffer, AccessPath path) 
{
   Visit(buffer->data, path->Attr("data"));
   Visit(buffer->shape, path->Attr("shape"));
   Visit(buffer->strides, path->Attr("strides"));
   Visit(buffer->elem_offset, path->Attr("elem_offset"));
 }
 
+// Default: buffer use sites do not re-visit buffer fields. Buffer fields
+// (shape, strides, elem_offset) are visited at the definition site via
+// VisitBufferDef/EnterDef. Re-visiting at use sites would require those
+// variables to be in scope at every use, which may not hold when buffers
+// are allocated in a different scope than where they are used.
+void TIRVisitorWithPath::VisitBufferUse(const Buffer& buffer, AccessPath path) 
{}
+
 void TIRVisitorWithPath::Visit(const BufferRegion& region, AccessPath path) {
-  Visit(region->buffer, path->Attr("buffer"));
+  VisitBufferUse(region->buffer, path->Attr("buffer"));
   Visit(region->region, path->Attr("region"));
 }
 
@@ -225,7 +228,7 @@ void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* 
op, AccessPath path) {
 
 void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, AccessPath 
path) {
   Visit(op->value, path->Attr("value"));
-  Visit(op->buffer, path->Attr("buffer"));
+  VisitBufferUse(op->buffer, path->Attr("buffer"));
   Visit(op->indices, path->Attr("indices"));
 }
 
@@ -308,7 +311,7 @@ void TIRVisitorWithPath::VisitExpr_(const SizeVarNode* op, 
AccessPath path) {
 }
 
 void TIRVisitorWithPath::VisitExpr_(const BufferLoadNode* op, AccessPath path) 
{
-  Visit(op->buffer, path->Attr("buffer"));
+  VisitBufferUse(op->buffer, path->Attr("buffer"));
   Visit(op->indices, path->Attr("indices"));
 }
 
diff --git a/src/tir/ir/tir_visitor_with_path.h 
b/src/tir/ir/tir_visitor_with_path.h
index 51e435b47d..f5189ae61c 100644
--- a/src/tir/ir/tir_visitor_with_path.h
+++ b/src/tir/ir/tir_visitor_with_path.h
@@ -53,12 +53,18 @@ class TIRVisitorWithPath
   // Delegate to ExprFunctor::VisitStmt for Stmt, and any subclasses
   inline void Visit(const Stmt& obj, ffi::reflection::AccessPath path) { 
VisitStmt(obj, path); }
 
+  // Visit a buffer at a use site (BufferLoad, BufferStore, reads/writes).
+  // By default, does not re-visit buffer fields (shape, strides, elem_offset),
+  // as those are visited at the definition site via EnterDef.
+  virtual void VisitBufferUse(const Buffer& obj, ffi::reflection::AccessPath 
path);
+  // Visit a buffer at a definition site. By default visits buffer fields.
+  virtual void VisitBufferDef(const Buffer& obj, ffi::reflection::AccessPath 
path);
+
   // Visitors for TIR constructs that are neither PrimExpr nor Stmt
   virtual void Visit(const IRModule& obj, ffi::reflection::AccessPath path);
   virtual void Visit(const PrimFunc& obj, ffi::reflection::AccessPath path);
   virtual void Visit(const GlobalVar& obj, ffi::reflection::AccessPath path) {}
   virtual void Visit(const Range& obj, ffi::reflection::AccessPath path);
-  virtual void Visit(const Buffer& obj, ffi::reflection::AccessPath path);
   virtual void Visit(const BufferRegion& obj, ffi::reflection::AccessPath 
path);
   virtual void Visit(const MatchBufferRegion& obj, ffi::reflection::AccessPath 
path);
   virtual void Visit(const IterVar& obj, ffi::reflection::AccessPath path);
diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc
index b9d8cdaab7..9398c2561e 100644
--- a/src/tir/transform/ir_utils.cc
+++ b/src/tir/transform/ir_utils.cc
@@ -205,6 +205,16 @@ class IRConvertSSA final : public StmtExprMutator {
     return func;
   }
 
+  // Do not use the base VisitBufferDef for buffer remapping.
+  //
+  // IRConvertSSA has its own scoped buffer remapping via GetRemappedBuffer and
+  // buf_remap_, which handles SSA conversion of buffer data vars, shape, 
strides,
+  // and elem_offset with proper scope tracking. The base 
StmtMutator::VisitBufferDef
+  // would create a conflicting second remap (into base buffer_remap_) when 
called
+  // from the default DeclBuffer/AllocBuffer handlers, producing buffers with
+  // undefined SSA-renamed variables.
+  Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override { 
return buffer; }
+
   PrimExpr VisitExpr_(const VarNode* op) final { return 
GetRemappedVar(ffi::GetRef<Var>(op)); }
   PrimExpr VisitExpr_(const LetNode* op) final {
     const Var& v = op->var;
diff --git a/src/tir/transform/simplify.cc b/src/tir/transform/simplify.cc
index f06e52f328..af0fc4cf47 100644
--- a/src/tir/transform/simplify.cc
+++ b/src/tir/transform/simplify.cc
@@ -182,6 +182,19 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   using Parent::VisitStmt;
   using Parent::VisitStmt_;
 
+  // Do not simplify buffer definition fields (shape, strides, elem_offset).
+  //
+  // The simplifier's VisitExpr override calls analyzer_->Simplify() directly,
+  // bypassing the normal ExprMutator dispatch. This means BufferLoad 
expressions
+  // inside values (e.g., BufferStore value) skip VisitExpr_(BufferLoadNode*) 
and
+  // thus skip VisitBufferUse. If VisitBufferDef remaps buffers at DeclBuffer 
sites,
+  // the BufferLoad use sites won't pick up the remap, causing 
DeclBuffer/BufferLoad
+  // buffer identity divergence and well-formedness violations.
+  //
+  // Instead, we keep buffer definitions unchanged and rely on 
used_in_buffer_def_
+  // to prevent inlining LetStmt vars that appear in buffer definitions.
+  Buffer VisitBufferDef(const Buffer& buffer, bool alloc_data) override { 
return buffer; }
+
   PrimExpr VisitExpr(const PrimExpr& expr) final {
     if (config_->propagate_knowns_to_simplify_expressions) {
       return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), 
analyzer_);
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index b36420dfe2..3375be981c 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -157,7 +157,9 @@ TEST(IRF, StmtVisitor) {
     return AllocBuffer(buf, body);
   };
   v(fmaketest());
-  TVM_FFI_ICHECK_EQ(v.count, 1);
+  // AllocBuffer now visits buffer shape via VisitBufferDef.
+  // shape = {z, z} where z = x + 1, so x is visited twice from shape + once 
from body = 3
+  TVM_FFI_ICHECK_EQ(v.count, 3);
 
   {
     // tests for block and block_realize
@@ -176,7 +178,10 @@ TEST(IRF, StmtVisitor) {
 
     v.count = 0;
     v(block_realize);
-    TVM_FFI_ICHECK_EQ(v.count, 5);
+    // Old count was 5 (x visited in reads/writes/match_buffers range + init 
body + body).
+    // VisitBufferDef now also visits AllocBuffer's buffer shape {x+1, x+1} in 
both init and body.
+    // Each adds 2 VarNode visits (x in each shape element), so 5 + 4 = 9.
+    TVM_FFI_ICHECK_EQ(v.count, 9);
   }
 }
 
@@ -221,8 +226,9 @@ TEST(IRF, StmtMutator) {
     auto* arrptr = arr.get();
     arr.MutateByApply([&](Stmt s) { return v(std::move(s)); });
     TVM_FFI_ICHECK(arr.get() == arrptr);
-    // buffer is not mutated (AllocBuffer mutator only visits body)
-    TVM_FFI_ICHECK(arr[0].as<AllocBufferNode>()->buffer.get() == bufptr);
+    // buffer IS mutated now (AllocBuffer mutator visits buffer shape via 
VisitBufferDef)
+    // shape was {1, x+1}, mutator transforms x+1 -> x, so buffer changes
+    TVM_FFI_ICHECK(arr[0].as<AllocBufferNode>()->buffer.get() != bufptr);
     // body is mutated: x+1 -> x
     TVM_FFI_ICHECK(!arr[0].as<AllocBufferNode>()->body.same_as(bref));
     
TVM_FFI_ICHECK(arr[0].as<AllocBufferNode>()->body.as<EvaluateNode>()->value.same_as(x));
@@ -270,7 +276,8 @@ TEST(IRF, StmtMutator) {
     body = v(std::move(body));
     // the seq get flattened
     TVM_FFI_ICHECK(body.as<SeqStmtNode>()->size() == 3);
-    
TVM_FFI_ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocBufferNode>()->buffer.get()
 == bufptr);
+    // buffer is now mutated (shape x+1 -> x via VisitBufferDef)
+    
TVM_FFI_ICHECK(body.as<SeqStmtNode>()->seq[0].as<AllocBufferNode>()->buffer.get()
 != bufptr);
     TVM_FFI_ICHECK(body.as<SeqStmtNode>()->seq[1].get() == ref2);
   }
 
@@ -333,36 +340,43 @@ TEST(IRF, Substitute) {
   using namespace tvm::tir;
   DataType dtype = DataType::Float(32);
   Var x("x", PointerType(PrimType(dtype), ""));
-  auto fmaketest = [&]() {
-    Buffer buffer{/*data=*/x,
+  Var n("n", DataType::Int(32));
+
+  auto fmakebuffer = [&]() {
+    return Buffer{/*data=*/x,
                   /*dtype=*/DataType::Float(32),
-                  /*shape=*/{},
+                  /*shape=*/{n},
                   /*strides=*/{},
                   /*elem_offset=*/NullValue<PrimExpr>(),
                   /*name=*/"buf",
                   /*data_alignment=*/1,
                   /*offset_factor=*/1,
                   /*buffer_type=*/BufferType::kDefault};
-    return BufferLoad(buffer, {});
   };
 
   {
-    // test substitute buffer var
+    // test substitute buffer data var and shape var via DeclBuffer
     Var y = x.copy_with_suffix("subst");
-    BufferLoad buffer_load = fmaketest();
+    Var m("m", DataType::Int(32));
+    Buffer buffer = fmakebuffer();
+    Stmt store = BufferStore(buffer, FloatImm(dtype, 0), 
{IntImm(DataType::Int(32), 0)});
+    Stmt decl = DeclBuffer(buffer, store);
     auto f_subst = [&](const Var& var) -> ffi::Optional<PrimExpr> {
-      if (var.same_as(x)) {
-        return y;
-      }
+      if (var.same_as(x)) return y;
+      if (var.same_as(n)) return m;
       return std::nullopt;
     };
-    BufferLoad new_buffer_load = Downcast<BufferLoad>(Substitute(buffer_load, 
f_subst));
-    TVM_FFI_ICHECK(new_buffer_load->buffer->data.same_as(y));
+    Stmt new_decl = Substitute(decl, f_subst);
+    auto* decl_node = new_decl.as<DeclBufferNode>();
+    TVM_FFI_ICHECK(decl_node != nullptr);
+    TVM_FFI_ICHECK(decl_node->buffer->data.same_as(y));
+    TVM_FFI_ICHECK(decl_node->buffer->shape[0].same_as(m));
   }
 
   {
-    // test identity substitution
-    PrimExpr expr = fmaketest();
+    // test identity substitution on expression
+    Buffer buffer = fmakebuffer();
+    PrimExpr expr = BufferLoad(buffer, {IntImm(DataType::Int(32), 0)});
     auto f_subst = [&](const Var& var) -> ffi::Optional<PrimExpr> { return 
var; };
     PrimExpr new_expr = Substitute(expr, f_subst);
     // the expression is not changed
diff --git a/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py 
b/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py
new file mode 100644
index 0000000000..3fd17830fd
--- /dev/null
+++ b/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for tir.analysis.undefined_vars (VarUseDefAnalyzer)."""
+
+import tvm
+import tvm.testing
+from tvm import tir
+
+
+def test_decl_buffer_data_is_use():
+    """DeclBuffer's data var should be reported as undefined (USE), not 
defined.
+
+    When UndefinedVars encounters a DeclBuffer, the data pointer references
+    an existing variable from the enclosing scope.  It must appear in the
+    undefined list so that callers (e.g., CreateComputeScope) capture it.
+    """
+    n = tir.SizeVar("n", "int32")
+    from tvm.ir import PointerType, PrimType
+
+    data_ptr = tir.Var("buf_data", PointerType(PrimType("float32")))
+    buf = tir.decl_buffer((n,), "float32", "buf", data=data_ptr)
+
+    body = tir.Evaluate(tir.BufferLoad(buf, [0]))
+    decl = tir.DeclBuffer(buf, body)
+
+    undef = tvm.tir.analysis.undefined_vars(decl, [])
+    undef_names = {v.name for v in undef}
+    # data_ptr must be undefined (it comes from outside the DeclBuffer)
+    assert "buf_data" in undef_names, f"Expected buf_data in undefined vars, 
got {undef_names}"
+
+
+def test_decl_buffer_elem_offset_is_use():
+    """DeclBuffer's elem_offset var should be reported as undefined (USE).
+
+    After FlattenBuffer, DeclBuffer nodes carry elem_offset vars from
+    match_buffer entries.  These must appear in the undefined list.
+    """
+    from tvm.ir import PointerType, PrimType
+
+    n = tir.SizeVar("n", "int32")
+    data_ptr = tir.Var("buf_data", PointerType(PrimType("float32")))
+    elem_off = tir.Var("buf_elem_offset", "int32")
+    buf = tir.decl_buffer((n,), "float32", "buf", data=data_ptr, 
elem_offset=elem_off)
+
+    body = tir.Evaluate(tir.BufferLoad(buf, [0]))
+    decl = tir.DeclBuffer(buf, body)
+
+    undef = tvm.tir.analysis.undefined_vars(decl, [])
+    undef_names = {v.name for v in undef}
+    assert "buf_data" in undef_names, f"Expected buf_data in undefined vars, 
got {undef_names}"
+    assert "buf_elem_offset" in undef_names, (
+        f"Expected buf_elem_offset in undefined vars, got {undef_names}"
+    )
+
+
+def test_alloc_buffer_data_is_def():
+    """AllocBuffer's data var should NOT be reported as undefined (it's a DEF).
+
+    AllocBuffer allocates new storage — the data pointer is a new definition,
+    not a reference to an external variable.
+    """
+    n = tir.SizeVar("n", "int32")
+    buf = tir.decl_buffer((n,), "float32", "buf")
+
+    body = tir.Evaluate(tir.BufferLoad(buf, [0]))
+    alloc = tir.AllocBuffer(buf, body)
+
+    undef = tvm.tir.analysis.undefined_vars(alloc, [])
+    undef_names = {v.name for v in undef}
+    # data should NOT be undefined — AllocBuffer defines it
+    assert buf.data.name not in undef_names, (
+        f"AllocBuffer data should be defined, but found {buf.data.name} in 
{undef_names}"
+    )
+    # shape var n should be undefined (comes from enclosing scope)
+    assert "n" in undef_names, f"Expected shape var 'n' in undefined vars, got 
{undef_names}"
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py 
b/tests/python/tir-transform/test_tir_transform_simplify.py
index 3f73ed8e16..46e094acfa 100644
--- a/tests/python/tir-transform/test_tir_transform_simplify.py
+++ b/tests/python/tir-transform/test_tir_transform_simplify.py
@@ -1819,7 +1819,7 @@ def test_simplify_trivial_let_buffer_var():
 
 
 def test_simplify_trivial_let_elem_offset():
-    """A LetStmt used in a buffer definition should be retained"""
+    """A LetStmt used in a buffer definition should be retained, buffer fields 
unchanged"""
 
     @T.prim_func(private=True)
     def before(A_ptr: T.handle("float32"), A_offset: T.int32):
@@ -1827,14 +1827,18 @@ def test_simplify_trivial_let_elem_offset():
         A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
         A[0] = 42.0
 
-    expected = before
+    @T.prim_func(private=True)
+    def expected(A_ptr: T.handle("float32"), A_offset: T.int32):
+        A_offset_redef = A_offset
+        A = T.decl_buffer(1, "float32", elem_offset=A_offset_redef, data=A_ptr)
+        A[0] = 42.0
 
     after = _apply_simplify(before)
     tvm.ir.assert_structural_equal(after, expected)
 
 
 def test_simplify_trivial_let_shape():
-    """A LetStmt used in a buffer definition should be retained"""
+    """A LetStmt used in a buffer definition should be retained, buffer fields 
unchanged"""
 
     @T.prim_func(private=True)
     def before(A_ptr: T.handle("float32"), A_size: T.int32):
@@ -1842,14 +1846,18 @@ def test_simplify_trivial_let_shape():
         A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
         A[0] = 42.0
 
-    expected = before
+    @T.prim_func(private=True)
+    def expected(A_ptr: T.handle("float32"), A_size: T.int32):
+        A_size_redef = A_size
+        A = T.decl_buffer([A_size_redef], "float32", data=A_ptr)
+        A[0] = 42.0
 
     after = _apply_simplify(before)
     tvm.ir.assert_structural_equal(after, expected)
 
 
 def test_simplify_trivial_let_stride():
-    """A LetStmt used in a buffer definition should be retained"""
+    """A LetStmt used in a buffer definition should be retained, buffer fields 
unchanged"""
 
     @T.prim_func(private=True)
     def before(A_ptr: T.handle("float32"), A_stride: T.int32):
@@ -1857,12 +1865,37 @@ def test_simplify_trivial_let_stride():
         A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
         A[0] = 42.0
 
-    expected = before
+    @T.prim_func(private=True)
+    def expected(A_ptr: T.handle("float32"), A_stride: T.int32):
+        A_stride_redef = A_stride
+        A = T.decl_buffer(1, "float32", strides=[A_stride_redef], data=A_ptr)
+        A[0] = 42.0
 
     after = _apply_simplify(before)
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_simplify_buffer_identity_well_formed():
+    """Regression: Simplify must not diverge buffer identity between 
DeclBuffer and BufferLoad.
+
+    The simplifier's VisitExpr calls analyzer_->Simplify() directly, bypassing
+    normal ExprMutator dispatch.  If VisitBufferDef remaps a buffer at a 
DeclBuffer
+    site (e.g. inlining n_val -> n in the shape), BufferLoad inside a 
BufferStore
+    value would NOT pick up the remap because VisitBufferUse is never called.
+    This causes DeclBuffer/BufferLoad buffer identity divergence.
+    """
+
+    @T.prim_func(private=True)
+    def before(A_ptr: T.handle("float32"), B_ptr: T.handle("float32"), n: 
T.int32):
+        n_val = n
+        A = T.decl_buffer([n_val], "float32", data=A_ptr)
+        B = T.decl_buffer([n_val], "float32", data=B_ptr)
+        B[0] = A[0]
+
+    after = _apply_simplify(before)
+    tvm.tir.analysis.verify_well_formed(after)
+
+
 def test_buffer_shape_constraint():
     @I.ir_module(check_well_formed=False)
     class Before:

Reply via email to