This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 30b34d2521 [TIR] More flexible buffer compaction (#14021)
30b34d2521 is described below

commit 30b34d2521dd8b3fa99b9bd0297114cc05698e02
Author: wrongtest <[email protected]>
AuthorDate: Sat Apr 29 05:06:13 2023 +0800

    [TIR] More flexible buffer compaction (#14021)
    
    Hi there, this change aims to enhance the power and flexibility of the `CompactBufferAllocation` pass in two aspects:
    1. Free of pass order
        - It can work on both S-TIR (with opaque blocks) and lowered TIR. For example, one can now invoke `LoopPartition` and then perform partitioned-tile-aware buffer compaction.
        - We test existing cases to ensure that
          `(LowerOpaqueBlock . CompactBufferAllocation) (mod) == 
(CompactBufferAllocation . LowerOpaqueBlock) (mod)`
    
    2. Allow "non-strict" compaction
        - Add an option `is_strict`, defaulting to `True`, which denotes that compaction should respect the original buffer shape bound, so the compacted buffer region never exceeds the original.
        - If set to `False`, the "compacted" shape is determined entirely by the accessed buffer regions, so it may become larger than the original shape. This changes the original semantics for out-of-bound accesses, but may be helpful in certain use cases.
        - If the loop domain changes (e.g., aligning a loop dimension or removing an extra predicate), the accessed buffer region may grow; in that case the pass can provide a fallback implementation that adapts the buffer regions.
    
    About implementation issues:
    - To achieve [1]
        -  Buffer decl point:
            -  (s-tir) `T.alloc_buffer` in block form is handled without any 
change.
            -  (lowered) `T.decl_buffer` is newly handled. We assume it is at 
the proper position to dominate all accesses.
        -  Predicates
            - Predicates in `T.where`, `IfThenElse`, and `T.if_then_else` are now handled uniformly. As before, we first try to resolve the predicate directly to update the loop domain; on failure, we now keep it in a pending-conditions stack.
            - The visit logic is very similar to `IRVisitorWithAnalyzer`'s, but since `arith::Analyzer` and `arith::IntSet` are currently independent components, the change does not introduce `IRVisitorWithAnalyzer`.
        - Buffer accesses
            - There is no difference between the S-TIR and lowered forms. If there are pending predicates at an access point, we always try affine iterator analysis to obtain a tighter relaxed region.
        - Thread binding
            - To work on the lowered form, `attr::thread_extent` and `attr::virtual_thread` are handled to record the information necessary for thread relaxation.
        - Buffer aliasing
            - There is no change to `T.match_buffer` handling, and we explicitly disable compaction of aliased buffers in the lowered form.
        - Dim alignment
            - We utilize the annotation field of `T.allocate` and preserve `attr::buffer_dim_align` in the `LowerOpaqueBlock` pass when it converts `T.alloc_buffer` to `T.allocate`, so the compaction can collect dimension-alignment information in the lowered form.
    
    - To achieve [2]
        - This is more direct: `SimplifyAndNarrowBufferRegionFromNDIntSet` does not intersect the accessed region with the original shape if the `is_strict` option is overridden to `False`.
---
 include/tvm/tir/transform.h                        |    5 +-
 python/tvm/tir/transform/transform.py              |   10 +-
 src/tir/analysis/block_access_region_detector.cc   |   10 +-
 src/tir/schedule/primitive.h                       |    5 +-
 src/tir/schedule/primitive/block_annotate.cc       |    1 +
 src/tir/transforms/compact_buffer_region.cc        |  487 ++++--
 src/tir/transforms/ir_utils.cc                     |  101 +-
 src/tir/transforms/ir_utils.h                      |   44 +-
 src/tir/transforms/lower_opaque_block.cc           |   34 +-
 tests/python/unittest/test_tir_buffer.py           |   23 -
 .../test_tir_transform_compact_buffer_region.py    | 1730 +++++++++++---------
 11 files changed, 1470 insertions(+), 980 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 35aa392db2..8dee176277 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -456,10 +456,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
  *
  *  \endcode
  *
- *
+ * \param is_strict ensure the compacted shape always smaller than the 
original shape.
+ *   otherwise it allows to grow the shape to match actual accessed buffer 
regions.
  * \return The pass.
  */
-TVM_DLL Pass CompactBufferAllocation();
+TVM_DLL Pass CompactBufferAllocation(bool is_strict = true);
 
 /*!
  * This pass legalizes packed calls by wrapping their arguments into TVMValues
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 1df2ac76b5..0437f1a887 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -747,7 +747,7 @@ def ConvertBlocksToOpaque():
     return _ffi_api.ConvertBlocksToOpaque()  # type: ignore
 
 
-def CompactBufferAllocation():
+def CompactBufferAllocation(is_strict: bool = True):
     """Compact the buffer access region. by removing the buffer regions
     that are not accessed, i.e. narrowing the buffer shape and adjust
     the access region if necessary.
@@ -783,13 +783,19 @@ def CompactBufferAllocation():
                 for j in range(0, 16):
                     C[i, j] = B[0, j] + 1
 
+    Parameters
+    ----------
+    is_strict : bool
+        Ensure the compacted shape to be always smaller than the original 
shape.
+        Otherwise it allows to grow the shape to match actual accessed buffer 
regions.
+
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
 
     """
-    return _ffi_api.CompactBufferAllocation()  # type: ignore
+    return _ffi_api.CompactBufferAllocation(is_strict)  # type: ignore
 
 
 def LowerMatchBuffer():
diff --git a/src/tir/analysis/block_access_region_detector.cc 
b/src/tir/analysis/block_access_region_detector.cc
index 057cec475d..a15cecabdd 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -60,6 +60,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
   std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
   /*! \brief Extra iteration range hint for free vars */
   std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
+  /*! \brief Unresolved conditions within current scope. */
+  std::vector<PrimExpr> pending_conditions_;
   /*! \brief The buffers that the current block reads */
   std::vector<Buffer> read_buffers_;
   /*! \brief The buffers that the current block writes */
@@ -164,12 +166,12 @@ void BlockReadWriteDetector::VisitStmt_(const 
IfThenElseNode* op) {
   VisitExpr(op->condition);
   {
     // Visit then branch
-    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, 
true);
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, 
&pending_conditions_);
     StmtExprVisitor::VisitStmt(op->then_case);
   }
   if (op->else_case) {
     // Visit else branch
-    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, 
false);
+    With<ConditionalBoundsContext> ctx(!op->condition, &dom_map_, &hint_map_, 
&pending_conditions_);
     StmtExprVisitor::VisitStmt(op->else_case.value());
   }
 }
@@ -207,12 +209,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* 
op) {
     VisitExpr(op->args[0]);
     {
       // Visit then branch
-      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, 
true);
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, 
&pending_conditions_);
       StmtExprVisitor::VisitExpr(op->args[1]);
     }
     {
       // Visit else branch
-      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, 
false);
+      With<ConditionalBoundsContext> ctx(!op->args[0], &dom_map_, &hint_map_, 
&pending_conditions_);
       StmtExprVisitor::VisitExpr(op->args[2]);
     }
     return;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 5f5591ac45..78d1cab05c 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -481,10 +481,7 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, 
const StmtSRef& block_sr
  */
 TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int 
factor_axis);
 /******** Schedule: Block annotation ********/
-/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, 
offset) */
-using StorageAlignTuple = Array<Integer>;
-/*! \brief A list of StorageAlignTuple, used by StorageAlign */
-using StorageAlignAnnotation = Array<StorageAlignTuple>;
+
 /*!
  * \brief Set alignment requirement for specific dimension such that
  *        stride[axis] == k * factor + offset for some k. This is useful to 
set memory layout for
diff --git a/src/tir/schedule/primitive/block_annotate.cc 
b/src/tir/schedule/primitive/block_annotate.cc
index 3f1789b3d6..917e37c9a7 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -18,6 +18,7 @@
  */
 #include <tvm/tir/expr.h>
 
+#include "../../transforms/ir_utils.h"
 #include "../utils.h"
 
 namespace tvm {
diff --git a/src/tir/transforms/compact_buffer_region.cc 
b/src/tir/transforms/compact_buffer_region.cc
index a047191f16..6c6eb169a5 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -28,6 +28,7 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include <numeric>
 #include <stack>
 
 #include "../../support/arena.h"
@@ -41,51 +42,6 @@ namespace tir {
 
 using support::NDIntSet;
 
-/*!
- * \brief simplify and return the region collected by NDIntSet. return the 
original
- * buffer shape if the int_set is empty.
- */
-Region SimplifyAndNarrowBufferRegionFromNDIntSet(
-    const NDIntSet& nd_int_set, const Array<PrimExpr>& original_shape, 
arith::Analyzer* analyzer,
-    const std::vector<const ForNode*>& ancestor_loops) {
-  Array<Range> result;
-  result.reserve(nd_int_set.size());
-  for (size_t i = 0; i < nd_int_set.size(); ++i) {
-    const arith::IntSet& int_set = nd_int_set[i];
-    Range range = int_set.CoverRange(Range(/*begin=*/0, 
/*end=*/original_shape[i]));
-    PrimExpr min = analyzer->Simplify(tvm::max(0, range->min));
-    PrimExpr extent = range->extent;
-
-    // Apply stronger symbolic proof to help us remove symbolic min here.
-    if (!analyzer->CanProveLessEqualThanSymbolicShapeValue(range->extent, 
original_shape[i])) {
-      extent = tvm::min(original_shape[i], range->extent);
-    }
-
-    extent = analyzer->Simplify(extent);
-
-    // Check the buffer region is not loop dependent, since loop dependent
-    // allocation is not supported yet.
-    auto is_loop_var = [&ancestor_loops](const VarNode* v) {
-      return std::any_of(ancestor_loops.begin(), ancestor_loops.end(),
-                         [v](const ForNode* n) { return n->loop_var.get() == 
v; });
-    };
-    if (UsesVar(extent, is_loop_var)) {
-      // try estimate a constant upperbound on region's extent
-      int64_t upperbound = analyzer->const_int_bound(extent)->max_value;
-      if (upperbound != arith::ConstIntBound::kPosInf) {
-        extent = make_const(extent->dtype, upperbound);
-      } else {
-        // or else we have to fallback to full region
-        min = make_zero(original_shape[i]->dtype);
-        extent = original_shape[i];
-      }
-    }
-
-    result.push_back(Range::FromMinExtent(min, extent));
-  }
-  return result;
-}
-
 /*! \brief a more constrained bound estimate for n-dimentional int set */
 NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
                       const std::unordered_map<const VarNode*, arith::IntSet>& 
dom_map,
@@ -103,6 +59,44 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
   return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
 }
 
+/*!
+ * \brief Collect buffer aliasing information.
+ */
+class Var2BufferCollector : public StmtExprVisitor {
+ public:
+  /*! \brief Map the buffer var to all aliased buffers. */
+  std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash, 
ObjectPtrEqual>, ObjectPtrHash,
+                     ObjectPtrEqual>
+      var2buffer_;
+
+ private:
+  void VisitStmt_(const BufferStoreNode* op) final {
+    var2buffer_[op->buffer->data].insert(op->buffer);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    var2buffer_[op->buffer->data].insert(op->buffer);
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& buffer : op->alloc_buffers) {
+      var2buffer_[buffer->data].insert(buffer);
+    }
+    for (const MatchBufferRegion& region : op->match_buffers) {
+      var2buffer_[region->buffer->data].insert(region->buffer);
+      var2buffer_[region->source->buffer->data].insert(region->source->buffer);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const DeclBufferNode* op) final {
+    var2buffer_[op->buffer->data].insert(op->buffer);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+};
+
 /*!
  * \brief Collect the access region of each buffer.
  * \note The param buffer regions will not be collected.
@@ -110,10 +104,17 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
 class BufferAccessRegionCollector : public StmtExprVisitor {
  public:
   static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> 
Collect(
-      const PrimFunc& f) {
-    BufferAccessRegionCollector collector;
-    collector(f->body);
-    return std::move(collector.buffer_access_region_);
+      const PrimFunc& f, bool collect_inbound) {
+    BufferAccessRegionCollector region_collector(collect_inbound);
+
+    // collect buffer var to aliased buffer mapping
+    Var2BufferCollector var2buffer_collector;
+    var2buffer_collector(f->body);
+    std::swap(region_collector.var2buffer_, var2buffer_collector.var2buffer_);
+
+    // collect buffer access regions
+    region_collector(f->body);
+    return std::move(region_collector.buffer_access_region_);
   }
 
  private:
@@ -127,7 +128,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
         : buffer(buffer), accessed_region(region) {}
   };
 
-  BufferAccessRegionCollector() = default;
+  explicit BufferAccessRegionCollector(bool collect_inbound) : 
collect_inbound_(collect_inbound) {}
 
   /**************** Visitor overload ****************/
 
@@ -138,18 +139,23 @@ class BufferAccessRegionCollector : public 
StmtExprVisitor {
 
   void VisitExpr_(const BufferLoadNode* op) final {
     VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    StmtExprVisitor::VisitExpr_(op);
   }
 
   void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }
 
   void VisitStmt_(const ForNode* op) final {
-    ancestor_loops_.push_back(op);
     Range loop_range = Range::FromMinExtent(op->min, op->extent);
+    IterVar iter = op->kind == ForKind::kThreadBinding
+                       ? IterVar(Range(), op->loop_var, 
IterVarType::kThreadIndex,
+                                 op->thread_binding.value()->thread_tag)
+                       : IterVar(Range(), op->loop_var, IterVarType::kDataPar);
+    ancestor_iters_.push_back(iter);
     dom_analyzer_.Bind(op->loop_var, loop_range);
     dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
     StmtExprVisitor::VisitStmt_(op);
     dom_map_.erase(op->loop_var.get());
-    ancestor_loops_.pop_back();
+    ancestor_iters_.pop_back();
   }
 
   void VisitStmt_(const LetStmtNode* op) final {
@@ -181,12 +187,14 @@ class BufferAccessRegionCollector : public 
StmtExprVisitor {
     StmtExprVisitor::VisitExpr(op->condition);
     {
       // Visit then branch
-      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, 
true);
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_,
+                                         &pending_conditions_);
       StmtExprVisitor::VisitStmt(op->then_case);
     }
     if (op->else_case) {
       // Visit else branch
-      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, 
false);
+      With<ConditionalBoundsContext> ctx(!op->condition, &dom_map_, &hint_map_,
+                                         &pending_conditions_);
       StmtExprVisitor::VisitStmt(op->else_case.value());
     }
   }
@@ -197,12 +205,14 @@ class BufferAccessRegionCollector : public 
StmtExprVisitor {
       StmtExprVisitor::VisitExpr(op->args[0]);
       {
         // Visit then branch
-        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, 
true);
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_,
+                                           &pending_conditions_);
         StmtExprVisitor::VisitExpr(op->args[1]);
       }
       {
         // Visit else branch
-        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, 
false);
+        With<ConditionalBoundsContext> ctx(!op->args[0], &dom_map_, &hint_map_,
+                                           &pending_conditions_);
         StmtExprVisitor::VisitExpr(op->args[2]);
       }
       return;
@@ -227,9 +237,9 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       auto& regions = access_annotations_[p.first];
       p.second.swap(regions);
     }
-    // Step 2. Record relax position of ancestor_loops_ into 
buffer_var_in_scope_
+    // Step 2. Record relax position of ancestor_loops_
     for (const Buffer& buffer : op->alloc_buffers) {
-      buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer, 
ancestor_loops_.size()));
+      VisitBufferDef(buffer->data);
     }
     // Step 3. Visit match buffers
     for (const MatchBufferRegion& region : op->match_buffers) {
@@ -248,37 +258,77 @@ class BufferAccessRegionCollector : public 
StmtExprVisitor {
     }
     // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner 
buffers.
     for (const Buffer& buffer : op->alloc_buffers) {
-      auto it = relaxed_accesses_.find(buffer);
-      ICHECK(it != relaxed_accesses_.end())
-          << buffer << " is allocated but not accessed within block scope";
-      const NDIntSet& nd_int_set = it->second;
-      buffer_access_region_[buffer] = 
SimplifyAndNarrowBufferRegionFromNDIntSet(
-          nd_int_set, buffer->shape, &dom_analyzer_, ancestor_loops_);
+      ICHECK_EQ(var2buffer_[buffer->data].size(), 1)
+          << "Block allocation buffer shoud not be alised";
+      SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
     }
   }
 
   void VisitStmt_(const BlockRealizeNode* op) final {
-    PrimExpr cur_predicate = predicate_in_scope;
-    predicate_in_scope = op->predicate;
+    With<ConditionalBoundsContext> ctx(op->predicate, &dom_map_, &hint_map_, 
&pending_conditions_);
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const AllocateNode* op) final {
+    auto it = var2buffer_.find(op->buffer_var);
+
+    // Do not make compaction when the buffer def and
+    // the allocation is not one-to-one with the same dtype.
+    if (it == var2buffer_.end() || it->second.size() > 1) {
+      return StmtExprVisitor::VisitStmt_(op);
+    }
+    const Buffer& buffer = *it->second.begin();
+    if (buffer->dtype != op->dtype) {
+      return StmtExprVisitor::VisitStmt_(op);
+    }
+
+    // Step 0. Record relax position of ancestor_loops_
+    VisitBufferDef(op->buffer_var);
+    // Step 1. Visit block body recursively
+    StmtExprVisitor::VisitStmt(op->body);
+    // Step 2. Update buffer_access_region_ from relaxed_accesses_ for inner 
buffers.
+    SimplifyAndNarrowBufferRegionFromNDIntSet(buffer);
+  }
+
+  void VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent || op->attr_key == 
attr::virtual_thread) {
+      IterVar iter = Downcast<IterVar>(op->node);
+      ancestor_iters_.push_back(iter);
+      Range dom = iter->dom;
+      if (!dom.defined()) {  // dom is empty for legacy te schedule
+        dom = Range::FromMinExtent(0, op->value);
+      }
+      dom_analyzer_.Bind(iter->var, dom);
+      dom_map_.emplace(iter->var.get(), arith::IntSet::FromRange(dom));
+      StmtExprVisitor::VisitStmt_(op);
+      dom_map_.erase(iter->var.get());
+      ancestor_iters_.pop_back();
+      return;
+    }
     StmtExprVisitor::VisitStmt_(op);
-    predicate_in_scope = cur_predicate;
   }
 
   /**************** Helper functions ****************/
 
+  /*! \brief Record information on the buffer defining point. */
+  void VisitBufferDef(const Var& buffer_data) {
+    auto it = buffer_scope_depth_.find(buffer_data);
+    ICHECK(it == buffer_scope_depth_.end()) << buffer_data << " has duplicate 
definitions";
+    buffer_scope_depth_.insert(it, {buffer_data, ancestor_iters_.size()});
+  }
+
   void VisitBufferAccess(const BufferRegion& buffer_region) {
-    const BufferNode* buffer = buffer_region->buffer.get();
-    auto it = buffer_var_in_scope_.find(buffer->data);
-    if (it != buffer_var_in_scope_.end()) {
-      const Buffer& buffer = it->second.first;
-      size_t n_ancestor_loops = it->second.second;
+    const Buffer& buffer = buffer_region->buffer;
+    auto it = buffer_scope_depth_.find(buffer->data);
+    if (it != buffer_scope_depth_.end()) {
+      size_t n_ancestor_loops = it->second;
       // Step 1. Stop ancestor loop vars out of the allocation block from
       // being relaxed unless NeedRelaxThread() is true.
       std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
-        const ForNode* loop = ancestor_loops_[i];
-        const VarNode* v = loop->loop_var.get();
-        if (NeedRelaxThread(GetRef<For>(loop), 
runtime::StorageScope::Create(buffer.scope()))) {
+        const IterVar& iter = ancestor_iters_[i];
+        const VarNode* v = iter->var.get();
+        if (NeedRelaxThread(iter, 
runtime::StorageScope::Create(buffer.scope()))) {
           continue;
         }
         auto dom_it = dom_map_.find(v);
@@ -288,12 +338,21 @@ class BufferAccessRegionCollector : public 
StmtExprVisitor {
         dom_map_.erase(dom_it);
       }
       // Step 2. Relax the access region
+      auto normalize_pred = [](const PrimExpr& pred) {
+        if (pred->dtype.is_bool()) return pred;
+        return pred != make_zero(pred->dtype);
+      };
+      PrimExpr predicate = dom_analyzer_.Simplify(
+          std::accumulate(pending_conditions_.begin(), 
pending_conditions_.end(), const_true(),
+                          [normalize_pred](const PrimExpr& x, const PrimExpr& 
y) {
+                            return normalize_pred(x) && normalize_pred(y);
+                          }));
       NDIntSet nd_int_set =
-          NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, 
&dom_analyzer_);
+          NDIntSetEval(buffer_region->region, predicate, dom_map_, 
&dom_analyzer_);
 
       // Step 3. Restore the non-relaxed ancestor loops domain
       for (size_t i = 0; i < n_ancestor_loops; ++i) {
-        const VarNode* v = ancestor_loops_[i]->loop_var.get();
+        const VarNode* v = ancestor_iters_[i]->var.get();
         dom_map_.emplace(v, non_relaxed[i]);
       }
       // Step 4. Update relaxed_accesses_ dict
@@ -307,9 +366,11 @@ class BufferAccessRegionCollector : public StmtExprVisitor 
{
   }
 
   void VisitBufferVar(const Var& var) {
-    auto it = buffer_var_in_scope_.find(var);
-    if (it != buffer_var_in_scope_.end()) {
-      const Buffer& buffer = it->second.first;
+    auto it = var2buffer_.find(var);
+    if (it == var2buffer_.end()) {
+      return;
+    }
+    for (const Buffer& buffer : it->second) {
       auto annotation_it = access_annotations_.find(buffer);
       if (annotation_it != access_annotations_.end()) {
         // opaque buffer has explicit accessed region annotations
@@ -322,92 +383,135 @@ class BufferAccessRegionCollector : public 
StmtExprVisitor {
     }
   }
 
-  /*! \brief Check whether the thread binding loop should be relaxed with 
given storage scope. */
-  static bool NeedRelaxThread(const For& loop, const runtime::StorageScope& 
scope) {
-    if (loop->kind != ForKind::kThreadBinding) {
+  /*! \brief Check whether the thread binding iter should be relaxed with 
given storage scope. */
+  static bool NeedRelaxThread(const IterVar& iter, const 
runtime::StorageScope& scope) {
+    if (iter->iter_type != IterVarType::kThreadIndex) {
       return false;
     }
-    ICHECK(loop->thread_binding.defined());
-    const String& thread_tag = loop->thread_binding.value()->thread_tag;
+    ICHECK(iter->thread_tag.defined());
     // When there is warp memory
     // threadIdx.x must be set to be warp index.
-    return CanRelaxStorageUnderThread(scope, 
runtime::ThreadScope::Create(thread_tag));
+    return CanRelaxStorageUnderThread(scope, 
runtime::ThreadScope::Create((iter->thread_tag)));
+  }
+
+  /*!
+   * \brief simplify and narrow down the region collected by NDIntSet.
+   * Update the `relaxed_accesses_` dict. If `collect_inbound_` is true,
+   * the result region would never exceed the original buffer shape.
+   */
+  void SimplifyAndNarrowBufferRegionFromNDIntSet(const Buffer& buffer) {
+    auto it = relaxed_accesses_.find(buffer);
+    ICHECK(it != relaxed_accesses_.end())
+        << buffer << " is allocated but not accessed within block scope";
+
+    const Array<PrimExpr>& original_shape = buffer->shape;
+    const NDIntSet& nd_int_set = it->second;
+    Array<Range>& result_region = buffer_access_region_[buffer];
+    result_region.resize(nd_int_set.size());
+
+    for (size_t i = 0; i < nd_int_set.size(); ++i) {
+      const arith::IntSet& int_set = nd_int_set[i];
+      Range original =
+          Range(/*begin=*/make_zero(original_shape[i]->dtype), 
/*end=*/original_shape[i]);
+      Range range = int_set.CoverRange(original);
+      PrimExpr min, extent;
+      if (collect_inbound_) {
+        min = dom_analyzer_.Simplify(tvm::max(0, range->min));
+        extent = range->extent;
+        // Apply stronger symbolic proof to help us remove symbolic min here.
+        if (!dom_analyzer_.CanProveLessEqualThanSymbolicShapeValue(extent, 
original_shape[i])) {
+          extent = tvm::min(original_shape[i], range->extent);
+        }
+        extent = dom_analyzer_.Simplify(extent);
+      } else {
+        min = dom_analyzer_.Simplify(range->min);
+        extent = dom_analyzer_.Simplify(range->extent);
+      }
+
+      // We check the buffer extent is pure and not loop dependent, since loop 
dependent
+      // or data dependent allocation is not supported yet. Otherwise we should
+      // fallback to use original buffer shape.
+      if (SideEffect(extent) > CallEffectKind::kPure) {
+        result_region.Set(i, original);
+        continue;
+      }
+      auto is_loop_var = [this](const VarNode* v) {
+        return std::any_of(ancestor_iters_.begin(), ancestor_iters_.end(),
+                           [v](const IterVar& n) { return n->var.get() == v; 
});
+      };
+      if (UsesVar(extent, is_loop_var)) {
+        // try estimate a constant upperbound on region's extent
+        int64_t upperbound = dom_analyzer_.const_int_bound(extent)->max_value;
+        if (upperbound != arith::ConstIntBound::kPosInf) {
+          extent = make_const(extent->dtype, upperbound);
+        } else {
+          result_region.Set(i, original);
+          continue;
+        }
+      }
+      result_region.Set(i, Range::FromMinExtent(min, extent));
+    }
   }
 
   /**************** Class members ****************/
-  /*! \brief The loops from the current node up to the root. */
-  std::vector<const ForNode*> ancestor_loops_;
+  /*! \brief Only collect accessed region within original buffer shape bound. 
*/
+  bool collect_inbound_{true};
+
+  /*! \brief The iteration scopes from the current node up to the root. */
+  std::vector<IterVar> ancestor_iters_;
 
   /*!
-   * \brief The vars of the buffer allocated under the current block.
-   * Map each buffer var to (buffer_obj, n_ancester_loop) pair, where
-   * n_ancester_loop is the loop num out of the current block.
-   * Tancestor_loops_[0: n_ancester_loop] should not be relaxed when
+   * \brief Map each buffer var to the n_ancester_loop. which is the loop 
depth at the
+   * define point. ancestor_loops_[0: n_ancester_loop] should not be relaxed 
when
    * we evaluate this buffer's access regions.
    */
-  std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, 
ObjectPtrEqual>
-      buffer_var_in_scope_;
-  /*! \brief The block predicate of current scope */
-  PrimExpr predicate_in_scope{true};
+  std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual> 
buffer_scope_depth_;
+
+  /*! \brief Map the buffer var to all aliased buffers. */
+  std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash, 
ObjectPtrEqual>, ObjectPtrHash,
+                     ObjectPtrEqual>
+      var2buffer_;
 
   /*! \brief The map from loop vars to their iter range. */
   std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
   /*! \brief Extra map from free vars to their iter range hints. */
   std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
+  /*! \brief Unresolved conditions within current scope. */
+  std::vector<PrimExpr> pending_conditions_;
   /*! \brief The analyzer aware of loop domains. */
   arith::Analyzer dom_analyzer_;
   /*! \brief The map from Buffer to it's relaxed access set. */
   std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> 
relaxed_accesses_;
-  /*! \brief The map from Buffer to it entire access region, used for 
returning. */
+
+  /*!
+   * \brief The map from Buffer to it entire access region, used for returning.
+   * The entire access region should get updated on the buffer's define point
+   * and we sanity check that every buffer is defined only once.
+   */
   std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> 
buffer_access_region_;
+
   /*! \brief The map from Buffer to it's access regions annotated by current 
block. */
   std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, 
ObjectPtrEqual>
       access_annotations_;
 };
 
-/*! \brief Collect storage alignment information from block annotations. */
-class StorageAlignCollector : public StmtVisitor {
- public:
-  static std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual> Collect(
-      const PrimFunc& f) {
-    StorageAlignCollector collector;
-    collector(f->body);
-    return std::move(collector.storage_align_);
-  }
-
- private:
-  void VisitStmt_(const BlockNode* op) final {
-    auto it = op->annotations.find(attr::buffer_dim_align);
-    if (it != op->annotations.end()) {
-      auto storage_align_annotation = 
Downcast<StorageAlignAnnotation>((*it).second);
-      for (const auto& storage_align_tuple : storage_align_annotation) {
-        int buffer_index = storage_align_tuple[0]->value;
-        const Buffer& buffer = op->writes[buffer_index]->buffer;
-        storage_align_[buffer].push_back(storage_align_tuple);
-      }
-    }
-    StmtVisitor::VisitStmt_(op);
-  }
-
-  /*! \brief The map from Buffer to its storage alignment information. */
-  std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual> storage_align_;
-};
-
 /*! \brief Reallocate the buffers with minimal region. */
 class BufferCompactor : public StmtExprMutator {
  public:
   static Stmt Compact(
       const PrimFunc& f,
       const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& 
regions,
-      const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual>&
+      const std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual>&
           storage_align) {
-    std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> 
buffer_info;
-
+    // collect buffer allocation info for no-alias buffers
+    std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> 
buffer_info;
     for (const auto& kv : regions) {
       const Buffer& buffer = kv.first;
+
+      // set dim alignment info
       Region region = kv.second;
-      BufferAllocInfo buffer_alloc_info(std::move(region));
-      auto it = storage_align.find(buffer);
+      BufferAllocInfo alloc_info;
+      auto it = storage_align.find(buffer->data);
       if (it != storage_align.end()) {
         std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
         for (const StorageAlignTuple& dim_align : (*it).second) {
@@ -417,9 +521,33 @@ class BufferCompactor : public StmtExprMutator {
           int offset = dim_align[3]->value;
           dim_aligns.at(dim) = {factor, offset};
         }
-        buffer_alloc_info.dim_aligns = std::move(dim_aligns);
+        alloc_info.dim_aligns = std::move(dim_aligns);
+      }
+
+      // prepare new buffer
+      Array<PrimExpr> shape = region.Map([](const Range& range) { return 
range->extent; });
+      Array<PrimExpr> strides;
+      if (alloc_info.dim_aligns.size()) {
+        ICHECK(alloc_info.dim_aligns.size() == shape.size());
+        strides.resize(shape.size());
+        PrimExpr stride = make_const(shape[0].dtype(), 1);
+        for (size_t i = shape.size(); i != 0; --i) {
+          size_t dim = i - 1;
+          if (alloc_info.dim_aligns[dim].align_factor != 0) {
+            PrimExpr factor = make_const(stride.dtype(), 
alloc_info.dim_aligns[dim].align_factor);
+            PrimExpr offset = make_const(stride.dtype(), 
alloc_info.dim_aligns[dim].align_offset);
+            stride = stride + indexmod(factor + offset - indexmod(stride, 
factor), factor);
+          }
+          strides.Set(dim, stride);
+          stride = stride * shape[dim];
+        }
       }
-      buffer_info.emplace(buffer, std::move(buffer_alloc_info));
+      ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      alloc_info.new_buffer = Buffer(std::move(n));
+      alloc_info.region = region;
+      buffer_info.emplace(buffer->data, std::move(alloc_info));
     }
     BufferCompactor compactor(std::move(buffer_info));
     Stmt stmt = compactor(f->body);
@@ -445,12 +573,10 @@ class BufferCompactor : public StmtExprMutator {
      * \note The value if NullOpt if the buffer do not need reallocate (e.g 
parameter buffer).
      */
     Buffer new_buffer;
-
-    explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
   };
 
   explicit BufferCompactor(
-      std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, 
ObjectPtrEqual> buffer_info)
+      std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> 
buffer_info)
       : buffer_info_(std::move(buffer_info)) {}
 
   Stmt VisitStmt_(const BufferStoreNode* _op) final {
@@ -471,7 +597,8 @@ class BufferCompactor : public StmtExprMutator {
     // Step 0. Check there is no Init part.
     ICHECK(!op->init.defined());
     // Step 1. Reallocate and rewrite alloc_buffers, also update 
BufferAllocInfo.
-    Array<Buffer> alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
+    Array<Buffer> alloc_buffers =
+        op->alloc_buffers.Map([this](const Buffer& buf) { return 
RewriteAllocBuffer(buf); });
     // Step 2. Recursively rewrite BufferLoad/BufferStore.
     Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
     // Step 3. Update block signature.
@@ -483,47 +610,45 @@ class BufferCompactor : public StmtExprMutator {
     return std::move(block);
   }
 
-  Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
-    Array<Buffer> result;
-    result.reserve(buffers.size());
-    for (const Buffer& buffer : buffers) {
-      auto it = buffer_info_.find(buffer);
-      ICHECK(it != buffer_info_.end());
-      BufferAllocInfo& info = it->second;
-      Array<PrimExpr> shape;
-      shape.reserve(info.region.size());
-      for (const Range& range : info.region) {
-        shape.push_back(range->extent);
-      }
-      Array<PrimExpr> strides;
-      if (info.dim_aligns.size()) {
-        ICHECK(info.dim_aligns.size() == shape.size());
-        strides.resize(shape.size());
-        PrimExpr stride = make_const(shape[0].dtype(), 1);
-        for (size_t i = shape.size(); i != 0; --i) {
-          size_t dim = i - 1;
-          if (info.dim_aligns[dim].align_factor != 0) {
-            PrimExpr factor = make_const(stride.dtype(), 
info.dim_aligns[dim].align_factor);
-            PrimExpr offset = make_const(stride.dtype(), 
info.dim_aligns[dim].align_offset);
-            stride = stride + indexmod(factor + offset - indexmod(stride, 
factor), factor);
-          }
-          strides.Set(dim, stride);
-          stride = stride * shape[dim];
-        }
-      }
-      ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
-      n->shape = std::move(shape);
-      n->strides = std::move(strides);
-      info.new_buffer = Buffer(std::move(n));
-      result.push_back(info.new_buffer);
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    Buffer new_buffer = RewriteAllocBuffer(op->buffer);
+    auto n = CopyOnWrite(op);
+    n->buffer = std::move(new_buffer);
+    n->body = VisitStmt(op->body);
+    return DeclBuffer(n);
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    Allocate allocate = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_info_.find(allocate->buffer_var);
+    if (it == buffer_info_.end()) {
+      return std::move(allocate);
     }
-    return result;
+    // Rewrite allocation shape if the corresponding buffer is in the 
buffer_info_
+    // dict and the dtype is consistent, which denotes there are no buffer 
aliasing
+    // and the compaction is safe.
+    const Buffer& new_buffer = it->second.new_buffer;
+    if (op->dtype != new_buffer->dtype) {
+      return std::move(allocate);
+    }
+    Array<PrimExpr> new_shape = GetBufferAllocationShape(new_buffer);
+    auto n = allocate.CopyOnWrite();
+    ICHECK(n->buffer_var.same_as(new_buffer->data));
+    n->extents = new_shape;
+    return std::move(allocate);
+  }
+
+  Buffer RewriteAllocBuffer(const Buffer& buffer) {
+    auto it = buffer_info_.find(buffer->data);
+    if (it != buffer_info_.end()) {
+      return it->second.new_buffer;
+    }
+    return buffer;
   }
 
   void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
-    auto it = buffer_info_.find(*buffer);
+    auto it = buffer_info_.find((*buffer)->data);
     if (it == buffer_info_.end()) {
-      // Skip if the buffer is parameter
       return;
     }
     const BufferAllocInfo& info = it->second;
@@ -539,7 +664,7 @@ class BufferCompactor : public StmtExprMutator {
   }
 
   void RewriteBufferRegion(Buffer* buffer, Region* region) const {
-    auto it = buffer_info_.find(*buffer);
+    auto it = buffer_info_.find((*buffer)->data);
     if (it == buffer_info_.end()) {
       // Skip if the buffer is parameter
       return;
@@ -580,18 +705,16 @@ class BufferCompactor : public StmtExprMutator {
     *match_buffers = std::move(result);
   }
 
-  /*! \brief The allocation information about each buffer. */
-  std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> 
buffer_info_;
+  /*! \brief Map buffer var to the allocation information about each buffer. */
+  std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> 
buffer_info_;
 };
 
-PrimFunc CompactBufferAllocation(PrimFunc f) {
+PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
   // Only apply this pass to TIR that is not from TE schedules
   if (!IsFromLegacyTESchedule(f)) {
     PrimFuncNode* fptr = f.CopyOnWrite();
-    std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
-        BufferAccessRegionCollector::Collect(f);
-    std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual>
-        storage_align = StorageAlignCollector::Collect(f);
+    auto region = BufferAccessRegionCollector::Collect(f, 
/*collect_inbound=*/is_strict);
+    auto storage_align = CollectStorageAlignAnnotation(f->body);
     fptr->body = BufferCompactor::Compact(f, region, storage_align);
     return f;
   } else {
@@ -601,9 +724,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) {
 
 namespace transform {
 
-Pass CompactBufferAllocation() {
+Pass CompactBufferAllocation(bool is_strict) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return CompactBufferAllocation(std::move(f));
+    return CompactBufferAllocation(std::move(f), is_strict);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {});
 }
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index e80772fda3..b591016fd8 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
       ICHECK(is_no_op(n->body));
       n->body = body;
       body = Stmt(n);
+    } else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
+      auto n = make_object<DeclBufferNode>(*decl_buffer);
+      ICHECK(is_no_op(n->body));
+      n->body = body;
+      body = Stmt(n);
     } else {
       LOG(FATAL) << "not supported nest type";
     }
@@ -387,6 +392,18 @@ String GetPtrStorageScope(Var buffer_var) {
   return ptr_type->storage_scope;
 }
 
+Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
+  Array<PrimExpr> alloc_shape = buffer->shape;
+  if (buffer->strides.size()) {
+    ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
+    for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
+      ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
+      alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
+    }
+  }
+  return alloc_shape;
+}
+
 Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
                                const Array<PrimExpr>& indices) {
   const Buffer& target = match_buffer->buffer;
@@ -438,11 +455,14 @@ Bool IsFromLegacyTESchedule(PrimFunc f) {
   return from_legacy_te_schedule.value();
 }
 
-Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
+Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
   // extract equations and related vars from condition expression.
   // currently only extract simple integral equations which could be solvable.
   arith::Analyzer analyzer;
-  PrimExpr condition = is_true_branch_ ? condition_ : 
analyzer.Simplify(!condition_);
+  PrimExpr condition = analyzer.Simplify(condition_);
+  if (is_const_int(condition)) {
+    return NullOpt;
+  }
   Array<PrimExpr> equations;
   Array<Var> vars;
   std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, 
&fvisit](const PrimExpr& e) {
@@ -485,7 +505,7 @@ Map<Var, Range> 
ConditionalBoundsContext::GetVarBoundsFromCondition() {
   };
   fvisit(condition);
   if (equations.empty() || vars.empty()) {
-    return Map<Var, Range>();
+    return NullOpt;
   }
   // build dom ranges for related vars
   Map<Var, Range> ranges;
@@ -506,22 +526,35 @@ Map<Var, Range> 
ConditionalBoundsContext::GetVarBoundsFromCondition() {
   }
   // solve constraints
   arith::IntConstraints constraint(vars, ranges, equations);
-  auto result = arith::SolveInequalitiesToRange(constraint);
-  return result->ranges;
+  arith::IntConstraints result = arith::SolveInequalitiesToRange(constraint);
+  return std::move(result);
 }
 
 ConditionalBoundsContext::ConditionalBoundsContext(
     const PrimExpr& condition, std::unordered_map<const VarNode*, 
arith::IntSet>* relax_map,
-    std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool 
is_true_branch)
+    std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
+    std::vector<PrimExpr>* pending_conditions)
     : condition_(condition),
       relax_map_(relax_map),
       hint_map_(hint_map),
-      is_true_branch_(is_true_branch) {}
+      pending_conditions_(pending_conditions),
+      origin_pending_conditions_num_(pending_conditions->size()) {}
 
 void ConditionalBoundsContext::EnterWithScope() {
-  for (const auto& p : GetVarBoundsFromCondition()) {
-    const auto* var = p.first.get();
-    arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
+  Optional<arith::IntConstraints> constraints = TrySolveCondition();
+  if (!constraints.defined()) {
+    // fail to process the condition, add to unresolved
+    pending_conditions_->push_back(condition_);
+    return;
+  }
+  for (const PrimExpr& unresolved : constraints.value()->relations) {
+    // add partially unresolved conditions
+    pending_conditions_->push_back(unresolved);
+  }
+  // update solved var ranges
+  for (const auto& kv : constraints.value()->ranges) {
+    const VarNode* var = kv.first.get();
+    arith::IntSet new_dom = arith::IntSet::FromRange(kv.second);
     auto relax_it = relax_map_->find(var);
     if (relax_it != relax_map_->end()) {
       // this is a bound for relaxed var
@@ -542,6 +575,7 @@ void ConditionalBoundsContext::EnterWithScope() {
 }
 
 void ConditionalBoundsContext::ExitWithScope() {
+  pending_conditions_->resize(origin_pending_conditions_num_);
   for (const auto& p : origin_map_) {
     const auto* var = p.first;
     auto relax_it = relax_map_->find(var);
@@ -568,6 +602,53 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const 
AttrStmtNode* op) {
   return std::make_pair(op->value, inner->value);
 }
 
+/*! \brief Collect storage alignment information from annotations. */
+class StorageAlignCollector : public StmtVisitor {
+ private:
+  friend std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual>
+  CollectStorageAlignAnnotation(const Stmt& body);
+
+  /*! \brief For s-stir, the alignment annotations reside in block 
annotations. */
+  void VisitStmt_(const BlockNode* op) final {
+    auto it = op->annotations.find(attr::buffer_dim_align);
+    if (it != op->annotations.end()) {
+      auto storage_align_annotation = 
Downcast<StorageAlignAnnotation>((*it).second);
+      for (const auto& storage_align_tuple : storage_align_annotation) {
+        int buffer_index = storage_align_tuple[0]->value;
+        const Buffer& buffer = op->writes[buffer_index]->buffer;
+        storage_align_[buffer->data].push_back(storage_align_tuple);
+      }
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief For lowered tir, the alignment annotations reside in allocate 
annotations. */
+  void VisitStmt_(const AllocateNode* op) final {
+    auto it = op->annotations.find(attr::buffer_dim_align);
+    if (it != op->annotations.end()) {
+      auto storage_align_annotation = 
Downcast<StorageAlignAnnotation>((*it).second);
+      for (const auto& storage_align_tuple : storage_align_annotation) {
+        int buffer_index = storage_align_tuple[0]->value;
+        // the first buffer idx info is meaningless for allocate
+        // stmt and should set as negative intentionally.
+        ICHECK_EQ(buffer_index, -1);
+        storage_align_[op->buffer_var].push_back(storage_align_tuple);
+      }
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The map from buffer var to its storage alignment information. */
+  std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual> storage_align_;
+};
+
+std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
+CollectStorageAlignAnnotation(const Stmt& body) {
+  StorageAlignCollector collector;
+  collector(body);
+  return std::move(collector.storage_align_);
+}
+
 namespace transform {
 Pass ConvertSSA() {
   auto pass_func = [](IRModule mod, PassContext ctx) {
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 6915a0e3ac..afaff34472 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -25,6 +25,7 @@
 #define TVM_TIR_TRANSFORMS_IR_UTILS_H_
 
 #include <tvm/arith/int_set.h>
+#include <tvm/arith/int_solver.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/support/with.h>
 #include <tvm/tir/builtin.h>
@@ -224,6 +225,13 @@ Array<PrimExpr> ConvertIndices(const MatchBufferRegion& 
match_buffer,
  */
 Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& 
region);
 
+/*!
+ * \brief Get stride aware buffer allocation shape from buffer.
+ * \param buffer The buffer object.
+ * \return shape The shape considering buffer strides.
+ */
+Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer);
+
 /*!
  * \brief Check if a given PrimFunc originated from a TE schedule.
  *
@@ -235,12 +243,12 @@ Region ConvertRegion(const MatchBufferRegion& 
match_buffer, const Region& region
 Bool IsFromLegacyTESchedule(PrimFunc f);
 
 /*!
- *\brief Context helper to update domain map within conditional scope.
- *
- * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 
20], thus `bounds[i]` is
- * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, 
&hint_map, true)` step
- *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> 
ctx(condition,
- *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
+ * \brief Context helper to update domain map within conditional scope.
+ * Assume the condition is `0 <= i && i < 9` and domain of i is [0, 20], Then
+ * `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, 
&constraints)`
+ * step into scope where dom_map[i] is [0, 8]; and
+ * `With<ConditionalBoundsContext> ctx(!condition, &relax_map, &hint_map, 
&constraints)`
+ * step into scope where dom_map[i] is [9, 20]
  */
 class ConditionalBoundsContext {
  private:
@@ -250,17 +258,17 @@ class ConditionalBoundsContext {
    * \param condition The condition holds on true branch.
    * \param relax_map The domain map for relaxed vars to update.
    * \param hint_map The domain map for free vars to update.
-   * \param is_true_branch Whether step into the branch where condition bounds 
holds.
+   * \param pending_conditions The stack of unresolved constraints.
    */
   ConditionalBoundsContext(const PrimExpr& condition,
                            std::unordered_map<const VarNode*, arith::IntSet>* 
relax_map,
                            std::unordered_map<const VarNode*, arith::IntSet>* 
hint_map,
-                           bool is_true_branch);
+                           std::vector<PrimExpr>* pending_constraints);
   void EnterWithScope();
   void ExitWithScope();
 
   /*! \brief Helper to solve related variable's bound within conditional 
scope.*/
-  Map<Var, Range> GetVarBoundsFromCondition();
+  Optional<arith::IntConstraints> TrySolveCondition();
 
   /*! \brief the condition holds on true branch. */
   const PrimExpr& condition_;
@@ -268,10 +276,12 @@ class ConditionalBoundsContext {
   std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
   /*! \brief domain map for free vars to update */
   std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
-  /*! \brief whether is on true branch */
-  bool is_true_branch_;
+  /*! \brief unresolved condition stack */
+  std::vector<PrimExpr>* pending_conditions_;
   /*! \brief used to record and restore original var bounds */
   std::unordered_map<const VarNode*, arith::IntSet> origin_map_;
+  /*! \brief used to record unresolved conditions num. */
+  size_t origin_pending_conditions_num_;
 };
 
 // Information of tensor core fragment.
@@ -321,6 +331,18 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const 
AttrStmtNode* op);
  */
 PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants);
 
+/*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, 
offset) */
+using StorageAlignTuple = Array<Integer>;
+/*! \brief A list of StorageAlignTuple, used by StorageAlign */
+using StorageAlignAnnotation = Array<StorageAlignTuple>;
+/*!
+ * \brief Collect storage alignment annotations for all buffer vars within 
body.
+ * \param body The stmt to collect.
+ * \return The result dict from buffer var to storage align annotations.
+ */
+std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>
+CollectStorageAlignAnnotation(const Stmt& body);
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_TRANSFORMS_IR_UTILS_H_
diff --git a/src/tir/transforms/lower_opaque_block.cc 
b/src/tir/transforms/lower_opaque_block.cc
index 9a702db69f..86892433b4 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -33,6 +33,13 @@ namespace tir {
  * \brief Remove Block to ensure that the TIR can not be scheduled again.
  */
 class OpaqueBlockLower : public StmtExprMutator {
+ public:
+  static Stmt Rewrite(Stmt body) {
+    OpaqueBlockLower lower;
+    lower.storage_align_ = CollectStorageAlignAnnotation(body);
+    return lower(std::move(body));
+  }
+
  private:
   Stmt VisitStmt_(const BlockRealizeNode* op) final {
     // We have convert blocks into opaque blocks in previous passes.
@@ -49,16 +56,22 @@ class OpaqueBlockLower : public StmtExprMutator {
     // Step 3. Handle allocations in reverse order
     for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
       const Buffer& buffer = new_block->alloc_buffers[i - 1];
-      Array<PrimExpr> new_shape = buffer->shape;
-      if (buffer->strides.size()) {
-        ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
-        for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
-          ICHECK(is_zero(floormod(buffer->strides[i - 1], 
buffer->strides[i])));
-          new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
+      Array<PrimExpr> allocation_shape = GetBufferAllocationShape(buffer);
+      body = DeclBuffer(buffer, std::move(body));
+      Map<String, ObjectRef> allocate_annotations;
+      auto it = storage_align_.find(buffer->data);
+      if (it != storage_align_.end()) {
+        StorageAlignAnnotation allocate_aligns;
+        for (auto tuple : it->second) {
+          ICHECK_EQ(tuple.size(), 4);
+          tuple.Set(0, -1);
+          allocate_aligns.push_back(tuple);
         }
+        allocate_annotations.Set(attr::buffer_dim_align, allocate_aligns);
       }
-      body = DeclBuffer(buffer, std::move(body));
-      body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), 
std::move(body));
+
+      body = Allocate(buffer->data, buffer->dtype, allocation_shape, 
const_true(), std::move(body),
+                      allocate_annotations);
     }
     // Step 4. Handle annotations, block annotations are not preserved by 
default.
     std::vector<std::pair<std::string, PrimExpr>> pragma_attrs;
@@ -181,13 +194,16 @@ class OpaqueBlockLower : public StmtExprMutator {
 
   /*! \brief Attr keys to preserve into loop annotations. */
   std::unordered_set<std::string> preserved_annotations_;
+
+  /*! \brief The map from buffer var to its storage alignment information. */
+  std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual> storage_align_;
 };
 
 PrimFunc LowerOpaqueBlock(PrimFunc f) {
   // Only apply this pass to TIR that is not from TE schedules
   if (!IsFromLegacyTESchedule(f)) {
     auto fptr = f.CopyOnWrite();
-    fptr->body = OpaqueBlockLower()(std::move(fptr->body));
+    fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
     return f;
   } else {
     return f;
diff --git a/tests/python/unittest/test_tir_buffer.py 
b/tests/python/unittest/test_tir_buffer.py
index 95ad81db88..e3b63d9315 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -99,29 +99,6 @@ def test_buffer_offset_of():
     tvm.ir.assert_structural_equal(offset, [n * 2 + 103])
 
 
-def test_buffer_vload_nullptr():
-    var = tvm.tir.Var("v", dtype="int32")
-    buf = tvm.tir.decl_buffer((1,), name="buf")
-    buf_load = tvm.tir.expr.BufferLoad(buffer=buf, 
indices=tvm.runtime.convert([0]))
-    buf_load_stmt = tvm.tir.stmt.Evaluate(buf_load)
-    for_loop = tvm.tir.stmt.For(
-        loop_var=var, kind=0, min_val=0, extent=tvm.tir.Cast("int32", 
buf_load), body=buf_load_stmt
-    )
-    buf_func = tvm.tir.PrimFunc(params={}, body=for_loop)
-    mod = tvm.IRModule({"main": buf_func})
-    # Trigger nullptr buffer bug by pass
-    with pytest.raises(tvm.error.TVMError) as cm:
-        mod = tvm.transform.Sequential(
-            [
-                tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
-                tvm.tir.transform.CompactBufferAllocation(),
-                tvm.tir.transform.LowerOpaqueBlock(),
-                tvm.tir.transform.FlattenBuffer(),
-            ]
-        )(mod)
-        assert "(n != nullptr) is false" in str(cm.execption)
-
-
 def test_buffer_index_merge_mult_mod():
     m = te.size_var("m")
     n = te.size_var("n")
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py 
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index e90539f3ef..c860bc0c55 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -17,713 +17,736 @@
 import tvm
 import tvm.testing
 from tvm import te
+from tvm import tir
 from tvm.script import tir as T
 
 
-def _check(original, transformed):
-    func = original
-    mod = tvm.IRModule.from_expr(func)
-    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
-    mod = tvm.tir.transform.Simplify()(mod)
-    transformed = 
tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
-    tvm.ir.assert_structural_equal(mod["main"], transformed)
+class BaseCompactTest:
+    """Base testcase class. The inherit testcase should include:
+    - `before` and `expected` primfunc used to check structural equality for 
the transformation.
+    - `is_lower_order_free` tag, defaults to True, denotes that we would check
+       (LowerOpaqueBlock . CompactBufferAllocation)(before) ==
+       (CompactBufferAllocation . LowerOpaqueBlock)(before)
+    - `is_strict` tag, defaults to True, controls the `is_strict` option of 
the compaction pass.
+    """
+
+    def test_compact(self):
+        is_lower_order_free = getattr(self, "is_lower_order_free", True)
+        is_strict = getattr(self, "is_strict_mode", True)
+
+        before = tvm.IRModule.from_expr(self.before)
+        expected = tvm.IRModule.from_expr(self.expected)
+        simplify = tvm.transform.Sequential([tir.transform.Simplify(), 
tir.transform.RemoveNoOp()])
+        after = 
simplify(tir.transform.CompactBufferAllocation(is_strict=is_strict)(before))
+        expected = simplify(expected)
+        try:
+            tvm.ir.assert_structural_equal(after, expected)
+        except ValueError as err:
+            script = tvm.IRModule(
+                {"expected": expected["main"], "after": after["main"], 
"before": before["main"]}
+            ).script()
+            raise ValueError(
+                f"Function after simplification did not match 
expected:\n{script}"
+            ) from err
+
+        if not is_lower_order_free:
+            return
+        lower_before_compact = tir.transform.LowerOpaqueBlock()(before)
+        lower_before_compact = 
tir.transform.CompactBufferAllocation(is_strict=is_strict)(
+            lower_before_compact
+        )
+        lower_before_compact = simplify(lower_before_compact)
+        lower_after_compact = tir.transform.LowerOpaqueBlock()(after)
+        lower_after_compact = simplify(lower_after_compact)
+        try:
+            tvm.ir.assert_structural_equal(lower_before_compact, 
lower_after_compact)
+        except ValueError as err:
+            script = tvm.IRModule(
+                {
+                    "lower_before_compact": lower_before_compact["main"],
+                    "lower_after_compact": lower_after_compact["main"],
+                    "before": before["main"],
+                }
+            ).script()
+            raise ValueError(
+                f"Function after simplification did not match 
expected:\n{script}"
+            ) from err
+
+
+class TestElemwise(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i in range(0, 16):
+            with T.block():
+                T.reads(A[i, 0:16])
+                T.writes(C[i, 0:16])
+                B = T.alloc_buffer((16, 16), "float32")
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(A[i, j])
+                        T.writes(B[i, j])
+                        B[i, j] = A[i, j] + 1.0
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(B[i, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[i, j] * 2.0
 
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i in range(0, 16):
+            with T.block():
+                T.reads(A[i, 0:16])
+                T.writes(C[i, 0:16])
+                B = T.alloc_buffer((1, 16), "float32")
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(A[i, j])
+                        T.writes(B[0, j])
+                        B[0, j] = A[i, j] + 1.0
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(B[0, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[0, j] * 2.0
 
[email protected]_func
-def elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in range(0, 16):
-        with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(C[i, 0:16])
-            B = T.alloc_buffer((16, 16), "float32")
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[i, j])
+
+class TestUnschedulableFunc(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i in range(0, 16):
+            with T.block():
+                T.reads(A[i, 0:16])
+                T.writes(C[i, 0:16])
+                B = T.alloc_buffer((16, 16), "float32")
+                for j in range(0, 16):
+                    T.evaluate(T.call_extern("dummy_extern_function", B.data, 
dtype="int32"))
                     B[i, j] = A[i, j] + 1.0
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(B[i, j])
-                    T.writes(C[i, j])
+                for j in range(0, 16):
                     C[i, j] = B[i, j] * 2.0
 
+    expected = before
 
[email protected]_func
-def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in range(0, 16):
-        with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(C[i, 0:16])
-            B = T.alloc_buffer((1, 16), "float32")
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[0, j])
-                    B[0, j] = A[i, j] + 1.0
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(B[0, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[0, j] * 2.0
 
+class TestParamBufferAccess(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (20, 20), "float32")
+        B = T.match_buffer(c, (20, 20), "float32")
+        for i in range(0, 16):
+            with T.block():
+                T.reads(A[i, 0:16])
+                T.writes(B[i, 0:16])
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(A[i, j])
+                        T.writes(B[i, j])
+                        B[i, j] = A[i, j] + 1.0
 
[email protected]_func
-def unschedulable_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in range(0, 16):
-        with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(C[i, 0:16])
-            B = T.alloc_buffer((16, 16), "float32")
-            for j in range(0, 16):
-                T.evaluate(T.call_extern("dummy_extern_function", B.data, dtype="int32"))
-                B[i, j] = A[i, j] + 1.0
-            for j in range(0, 16):
-                C[i, j] = B[i, j] * 2.0
-
-
-@T.prim_func
-def param_buffer_access_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (20, 20), "float32")
-    B = T.match_buffer(c, (20, 20), "float32")
-    for i in range(0, 16):
-        with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(B[i, 0:16])
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[i, j])
-                    B[i, j] = A[i, j] + 1.0
+    expected = before
 
 
-@T.prim_func
-def shared_mem_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
-        for i1 in T.thread_binding(0, 2, thread="vthread"):
-            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
-                with T.block():
-                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
-                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
-                    B = T.alloc_buffer((16, 16), "float32", scope="shared")
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(B[i0 * 8 + i1 * 4 + i2, j])
-                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
-
-
-@T.prim_func
-def compacted_shared_mem_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
-        for i1 in T.thread_binding(0, 2, thread="vthread"):
-            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
-                with T.block():
-                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
-                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
-                    B = T.alloc_buffer((8, 16), "float32", scope="shared")
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(B[i1 * 4 + i2, j])
-                            B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i1 * 4 + i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0
-
-
-@T.prim_func
-def warp_mem_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
-        for i1 in T.thread_binding(0, 2, thread="vthread"):
-            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
-                with T.block():
-                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
-                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
-                    B = T.alloc_buffer((16, 16), "float32", scope="warp")
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(B[i0 * 8 + i1 * 4 + i2, j])
-                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
-
-
-@T.prim_func
-def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
-        for i1 in T.thread_binding(0, 2, thread="vthread"):
-            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
-                with T.block():
-                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
-                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
-                    B = T.alloc_buffer((4, 16), "float32", scope="warp")
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
-                            T.writes(B[i2, j])
-                            B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[i2, j])
-                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
-                            C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
+class TestSharedMem(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+            for i1 in T.thread_binding(0, 2, thread="vthread"):
+                for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+                    with T.block():
+                        T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                        T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                        B = T.alloc_buffer((16, 16), "float32", scope="shared")
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                                B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
 
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+            for i1 in T.thread_binding(0, 2, thread="vthread"):
+                for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+                    with T.block():
+                        T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                        T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                        B = T.alloc_buffer((8, 16), "float32", scope="shared")
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(B[i1 * 4 + i2, j])
+                                B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(B[i1 * 4 + i2, j])
+                                T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0
 
-@T.prim_func
-def symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
-    A = T.match_buffer(a, (n * 8,), "float32")
-    C = T.match_buffer(c, (n * 8,), "float32")
-    for i in range(0, n):
-        with T.block():
-            T.reads(A[i * 8 : i * 8 + 8])
-            T.writes(C[i * 8 : i * 8 + 8])
-            B = T.alloc_buffer((n * 8,), "float32")
-            for j in range(0, 8):
-                with T.block():
-                    T.reads(A[i * 8 + j])
-                    T.writes(B[i * 8 + j])
-                    B[i * 8 + j] = A[i * 8 + j] + 1.0
-            for j in range(0, 8):
-                with T.block():
-                    T.reads(B[i * 8 + j])
-                    T.writes(C[i * 8 + j])
-                    C[i * 8 + j] = B[i * 8 + j] * 2.0
 
+class TestWarpMem(BaseCompactTest):  # scope="warp": B compacts to (4, 16), indexed by i2
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+            for i1 in T.thread_binding(0, 2, thread="vthread"):
+                for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+                    with T.block():
+                        T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                        T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                        B = T.alloc_buffer((16, 16), "float32", scope="warp")
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                                B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
 
-@T.prim_func
-def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
-    A = T.match_buffer(a, (n * 8,), "float32")
-    C = T.match_buffer(c, (n * 8,), "float32")
-    for i in range(0, n):
-        with T.block():
-            T.reads(A[i * 8 : i * 8 + 8])
-            T.writes(C[i * 8 : i * 8 + 8])
-            B = T.alloc_buffer((8,), "float32")
-            for j in range(0, 8):
-                with T.block():
-                    T.reads(A[i * 8 + j])
-                    T.writes(B[j])
-                    B[j] = A[i * 8 + j] + 1.0
-            for j in range(0, 8):
-                with T.block():
-                    T.reads(B[j])
-                    T.writes(C[i * 8 + j])
-                    C[i * 8 + j] = B[j] * 2.0
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
+            for i1 in T.thread_binding(0, 2, thread="vthread"):
+                for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
+                    with T.block():
+                        T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                        T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                        B = T.alloc_buffer((4, 16), "float32", scope="warp")
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                                T.writes(B[i2, j])
+                                B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                        for j in range(0, 16):
+                            with T.block():
+                                T.reads(B[i2, j])
+                                T.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                                C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
 
 
-@T.prim_func
-def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None:
-    A = T.match_buffer(a, (8, 8), "float32")
-    C = T.match_buffer(c, (8, 8), "float32")
-    for i in range(0, 8):
-        with T.block():
-            T.reads(A[0, 8])
-            T.writes(C[0, 8])
-            B = T.alloc_buffer((8, 8), "float32")
-            for j in range(0, 4):
-                with T.block():
-                    D = T.alloc_buffer((8, 8), "float32")
-                    T.reads(A[i, j])
-                    T.writes(B[i, j])
-                    for k in range(4, 8):
-                        D[k, j] = 1.0
-                    for k in range(2, 4):
-                        B[i, j] = A[i, j] + D[k, j]
-            for j in range(3, 5):
-                with T.block():
-                    T.reads(B[i, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[i, j]
-            for j in range(6, 8):
-                with T.block():
-                    T.reads(B[i, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[i, j]
+class TestSymbolic(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle, n: T.int32) -> None:
+        A = T.match_buffer(a, (n * 8,), "float32")
+        C = T.match_buffer(c, (n * 8,), "float32")
+        for i in range(0, n):
+            with T.block():
+                T.reads(A[i * 8 : i * 8 + 8])
+                T.writes(C[i * 8 : i * 8 + 8])
+                B = T.alloc_buffer((n * 8,), "float32")
+                for j in range(0, 8):
+                    with T.block():
+                        T.reads(A[i * 8 + j])
+                        T.writes(B[i * 8 + j])
+                        B[i * 8 + j] = A[i * 8 + j] + 1.0
+                for j in range(0, 8):
+                    with T.block():
+                        T.reads(B[i * 8 + j])
+                        T.writes(C[i * 8 + j])
+                        C[i * 8 + j] = B[i * 8 + j] * 2.0
+
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle, n: T.int32) -> None:
+        A = T.match_buffer(a, (n * 8,), "float32")
+        C = T.match_buffer(c, (n * 8,), "float32")
+        for i in range(0, n):
+            with T.block():
+                T.reads(A[i * 8 : i * 8 + 8])
+                T.writes(C[i * 8 : i * 8 + 8])
+                B = T.alloc_buffer((8,), "float32")
+                for j in range(0, 8):
+                    with T.block():
+                        T.reads(A[i * 8 + j])
+                        T.writes(B[j])
+                        B[j] = A[i * 8 + j] + 1.0
+                for j in range(0, 8):
+                    with T.block():
+                        T.reads(B[j])
+                        T.writes(C[i * 8 + j])
+                        C[i * 8 + j] = B[j] * 2.0
 
 
-@T.prim_func
-def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None:
-    A = T.match_buffer(a, (8, 8), "float32")
-    C = T.match_buffer(c, (8, 8), "float32")
-    for i in range(0, 8):
-        with T.block():
-            T.reads(A[0, 8])
-            T.writes(C[0, 8])
-            B = T.alloc_buffer((1, 8), "float32")
-            for j in range(0, 4):
-                with T.block():
-                    D = T.alloc_buffer((6, 1), "float32")
-                    T.reads(A[i, j])
-                    T.writes(B[0, j])
-                    for k in range(4, 8):
-                        D[k - 2, 0] = 1.0
-                    for k in range(2, 4):
-                        B[0, j] = A[i, j] + D[k - 2, 0]
-            for j in range(3, 5):
-                with T.block():
-                    T.reads(B[0, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[0, j]
-            for j in range(6, 8):
-                with T.block():
-                    T.reads(B[0, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[0, j]
+class TestComplexFunc(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle, n: T.int32) -> None:
+        A = T.match_buffer(a, (8, 8), "float32")
+        C = T.match_buffer(c, (8, 8), "float32")
+        for i in range(0, 8):
+            with T.block():
+                T.reads(A[0, 8])
+                T.writes(C[0, 8])
+                B = T.alloc_buffer((8, 8), "float32")
+                for j in range(0, 4):
+                    with T.block():
+                        D = T.alloc_buffer((8, 8), "float32")
+                        T.reads(A[i, j])
+                        T.writes(B[i, j])
+                        for k in range(4, 8):
+                            D[k, j] = 1.0
+                        for k in range(2, 4):
+                            B[i, j] = A[i, j] + D[k, j]
+                for j in range(3, 5):
+                    with T.block():
+                        T.reads(B[i, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[i, j]
+                for j in range(6, 8):
+                    with T.block():
+                        T.reads(B[i, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[i, j]
+
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle, n: T.int32) -> None:
+        A = T.match_buffer(a, (8, 8), "float32")
+        C = T.match_buffer(c, (8, 8), "float32")
+        for i in range(0, 8):
+            with T.block():
+                T.reads(A[0, 8])
+                T.writes(C[0, 8])
+                B = T.alloc_buffer((1, 8), "float32")
+                for j in range(0, 4):
+                    with T.block():
+                        D = T.alloc_buffer((6, 1), "float32")
+                        T.reads(A[i, j])
+                        T.writes(B[0, j])
+                        for k in range(4, 8):
+                            D[k - 2, 0] = 1.0
+                        for k in range(2, 4):
+                            B[0, j] = A[i, j] + D[k - 2, 0]
+                for j in range(3, 5):
+                    with T.block():
+                        T.reads(B[0, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[0, j]
+                for j in range(6, 8):
+                    with T.block():
+                        T.reads(B[0, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[0, j]
 
 
-@T.prim_func
-def match_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16))
-    C = T.match_buffer(c, (16, 16))
-    for i in range(0, 16):
-        with T.block():
-            A0 = T.match_buffer(A[i, 0:16], (16))
-            C0 = T.match_buffer(C[i, 0:16], (16))
-            B = T.alloc_buffer((16, 16))
+class TestMatchBuffer(BaseCompactTest):
+    is_lower_order_free = False
+
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16))
+        C = T.match_buffer(c, (16, 16))
+        for i in range(0, 16):
             with T.block():
-                B0 = T.match_buffer(B[i, 0:16], (16))
+                A0 = T.match_buffer(A[i, 0:16], (16))
+                C0 = T.match_buffer(C[i, 0:16], (16))
+                B = T.alloc_buffer((16, 16))
+                with T.block():
+                    B0 = T.match_buffer(B[i, 0:16], (16))
+                    for j in range(0, 16):
+                        with T.block():
+                            A1 = T.match_buffer(A0[j], ())
+                            B1 = T.match_buffer(B0[j], ())
+                            B1[()] = A1[()] + 1.0
                 for j in range(0, 16):
                     with T.block():
-                        A1 = T.match_buffer(A0[j], ())
-                        B1 = T.match_buffer(B0[j], ())
-                        B1[()] = A1[()] + 1.0
-            for j in range(0, 16):
+                        C1 = T.match_buffer(C0[j], ())
+                        B2 = T.match_buffer(B[i, j], ())
+                        C1[()] = B2[()] * 2.0
+
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16))
+        C = T.match_buffer(c, (16, 16))
+        for i in range(0, 16):
+            with T.block():
+                A0 = T.match_buffer(A[i, 0:16], (16))
+                C0 = T.match_buffer(C[i, 0:16], (16))
+                B = T.alloc_buffer((1, 16))
                 with T.block():
-                    C1 = T.match_buffer(C0[j], ())
-                    B2 = T.match_buffer(B[i, j], ())
-                    C1[()] = B2[()] * 2.0
+                    B0 = T.match_buffer(B[0, 0:16], (16))
+                    for j in range(0, 16):
+                        with T.block():
+                            A1 = T.match_buffer(A0[j], ())
+                            B1 = T.match_buffer(B0[j], ())
+                            B1[()] = A1[()] + 1.0
+                for j in range(0, 16):
+                    with T.block():
+                        C1 = T.match_buffer(C0[j], ())
+                        B2 = T.match_buffer(B[0, j], ())
+                        C1[()] = B2[()] * 2.0
 
 
-@T.prim_func
-def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16))
-    C = T.match_buffer(c, (16, 16))
-    for i in range(0, 16):
-        with T.block():
-            A0 = T.match_buffer(A[i, 0:16], (16))
-            C0 = T.match_buffer(C[i, 0:16], (16))
-            B = T.alloc_buffer((1, 16))
+class TestStorageAlign(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i in range(0, 16):
             with T.block():
-                B0 = T.match_buffer(B[0, 0:16], (16))
+                T.reads(A[i, 0:16])
+                T.writes(C[i, 0:16])
+                B = T.alloc_buffer((16, 16), "float32")
                 for j in range(0, 16):
                     with T.block():
-                        A1 = T.match_buffer(A0[j], ())
-                        B1 = T.match_buffer(B0[j], ())
-                        B1[()] = A1[()] + 1.0
-            for j in range(0, 16):
-                with T.block():
-                    C1 = T.match_buffer(C0[j], ())
-                    B2 = T.match_buffer(B[0, j], ())
-                    C1[()] = B2[()] * 2.0
+                        T.reads(A[i, j])
+                        T.writes(B[i, j])
+                        T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+                        B[i, j] = A[i, j] + 1.0
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(B[i, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[i, j] * 2.0
 
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (16, 16), "float32")
+        for i in range(0, 16):
+            with T.block():
+                T.reads(A[i, 0:16])
+                T.writes(C[i, 0:16])
+                B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(A[i, j])
+                        T.writes(B[0, j])
+                        T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+                        B[0, j] = A[i, j] + 1.0
+                for j in range(0, 16):
+                    with T.block():
+                        T.reads(B[0, j])
+                        T.writes(C[i, j])
+                        C[i, j] = B[0, j] * 2.0
 
-@T.prim_func
-def storage_align_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in range(0, 16):
+
+class TestPaddingPattern(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (16, 16), "float32")
+        C = T.match_buffer(c, (20, 20), "float32")
         with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(C[i, 0:16])
-            B = T.alloc_buffer((16, 16), "float32")
-            for j in range(0, 16):
+            B = T.alloc_buffer((20, 20), dtype="float32")
+            for i, j in T.grid(16, 16):
                 with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[i, j])
-                    T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
-                    B[i, j] = A[i, j] + 1.0
-            for j in range(0, 16):
+                    B[i, j] = A[i, j]
+            for i, j in T.grid(20, 20):
                 with T.block():
-                    T.reads(B[i, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[i, j] * 2.0
-
+                    C[i, j] = T.if_then_else(
+                        2 <= i and i < 18 and 2 <= j and j < 18,
+                        B[i - 2, j - 2],
+                        0.0,
+                        dtype="float32",
+                    )
 
-@T.prim_func
-def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in range(0, 16):
+    @T.prim_func
+    def expected(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, [16, 16], dtype="float32")
+        C = T.match_buffer(c, [20, 20], dtype="float32")
         with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(C[i, 0:16])
-            B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
-            for j in range(0, 16):
+            B = T.alloc_buffer([16, 16], dtype="float32")
+            for i, j in T.grid(16, 16):
                 with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[0, j])
-                    T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
-                    B[0, j] = A[i, j] + 1.0
-            for j in range(0, 16):
+                    B[i, j] = A[i, j]
+            for i, j in T.grid(20, 20):
                 with T.block():
-                    T.reads(B[0, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[0, j] * 2.0
-
-
-@T.prim_func
-def padding_pattern_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (20, 20), "float32")
-    with T.block():
-        B = T.alloc_buffer((20, 20), dtype="float32")
-        for i, j in T.grid(16, 16):
-            with T.block():
-                B[i, j] = A[i, j]
-        for i, j in T.grid(20, 20):
-            with T.block():
-                C[i, j] = T.if_then_else(
-                    2 <= i and i < 18 and 2 <= j and j < 18,
-                    B[i - 2, j - 2],
-                    0.0,
-                    dtype="float32",
-                )
+                    C[i, j] = T.if_then_else(
+                        2 <= i and i < 18 and 2 <= j and j < 18,
+                        B[i - 2, j - 2],
+                        0.0,
+                        dtype="float32",
+                    )
 
 
-@T.prim_func
-def compacted_padding_pattern_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [16, 16], dtype="float32")
-    C = T.match_buffer(c, [20, 20], dtype="float32")
-    with T.block():
-        B = T.alloc_buffer([16, 16], dtype="float32")
-        for i, j in T.grid(16, 16):
-            with T.block():
-                B[i, j] = A[i, j]
-        for i, j in T.grid(20, 20):
-            with T.block():
-                C[i, j] = T.if_then_else(
-                    2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], 0.0, dtype="float32"
+class TestPaddingPatternInlined(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle, b: T.handle) -> None:
+        X = T.match_buffer(a, [224, 224], dtype="float32")
+        Y = T.match_buffer(b, [224, 224], dtype="float32")
+        cache = T.alloc_buffer([224, 224], dtype="float32")
+        for h, w in T.grid(224, 224):
+            with T.block("cache"):
+                cache[h, w] = X[h, w]
+        for h, w, kh, kw in T.grid(224, 224, 3, 3):
+            with T.block("compute"):
+                Y[h, w] = T.max(
+                    Y[h, w],
+                    T.if_then_else(
+                        T.likely(1 <= h + kh, dtype="bool")
+                        and T.likely(h + kh < 225, dtype="bool")
+                        and T.likely(1 <= w + kw, dtype="bool")
+                        and T.likely(w + kw < 225, dtype="bool"),
+                        cache[h + kh - 1, w + kw - 1],
+                        0.0,
+                        dtype="float32",
+                    ),
                 )
 
+    @T.prim_func
+    def expected(X: T.Buffer((224, 224), "float32"), Y: T.Buffer((224, 224), "float32")) -> None:
+        cache = T.alloc_buffer([224, 224], dtype="float32")
+        for h, w in T.grid(224, 224):
+            with T.block("cache"):
+                cache[h, w] = X[h, w]
+        for h, w, kh, kw in T.grid(224, 224, 3, 3):
+            with T.block("compute"):
+                Y[h, w] = T.max(
+                    Y[h, w],
+                    T.if_then_else(
+                        T.likely(1 <= h + kh, dtype="bool")
+                        and T.likely(h + kh < 225, dtype="bool")
+                        and T.likely(1 <= w + kw, dtype="bool")
+                        and T.likely(w + kw < 225, dtype="bool"),
+                        cache[h + kh - 1, w + kw - 1],
+                        0.0,
+                        dtype="float32",
+                    ),
+                )
+
+
+class TestMemAccessInBranch(BaseCompactTest):
+    @T.prim_func
+    def before(a: T.handle) -> None:
+        A = T.match_buffer(a, (224, 224), "float32")
+        with T.block():
+            B1 = T.alloc_buffer((224, 224), dtype="float32")
+            B2 = T.alloc_buffer((224, 224), dtype="float32")
+            B3 = T.alloc_buffer((224, 224), dtype="float32")
+            B4 = T.alloc_buffer((224, 224), dtype="float32")
+            for i in range(0, 224):
+                for j in range(0, 224):
+                    with T.block():
+                        if i < 112 and j < 112:
+                            B1[i, j] = A[i, j] * 2.0
+                        else:
+                            B2[i, j] = A[i, j] + 3.0
+            for i in range(0, 224):
+                for j in range(0, 224):
+                    with T.block():
+                        if i < 112 or j < 112:
+                            B3[i, j] = A[i, j] * 2.0
+                        else:
+                            B4[i, j] = A[i, j] + 3.0
 
-@T.prim_func
-def padding_pattern_inlined(a: T.handle, b: T.handle) -> None:
-    X = T.match_buffer(a, [224, 224], dtype="float32")
-    Y = T.match_buffer(b, [224, 224], dtype="float32")
-    cache = T.alloc_buffer([224, 224], dtype="float32")
-    for h, w in T.grid(224, 224):
-        with T.block("cache"):
-            cache[h, w] = X[h, w]
-    for h, w, kh, kw in T.grid(224, 224, 3, 3):
-        with T.block("compute"):
-            Y[h, w] = T.max(
-                Y[h, w],
-                T.if_then_else(
-                    T.likely(1 <= h + kh, dtype="bool")
-                    and T.likely(h + kh < 225, dtype="bool")
-                    and T.likely(1 <= w + kw, dtype="bool")
-                    and T.likely(w + kw < 225, dtype="bool"),
-                    cache[h + kh - 1, w + kw - 1],
-                    0.0,
-                    dtype="float32",
-                ),
-            )
-
-
[email protected]_func
-def compacted_padding_pattern_inlined(
-    X: T.Buffer((224, 224), "float32"), Y: T.Buffer((224, 224), "float32")
-) -> None:
-    cache = T.alloc_buffer([224, 224], dtype="float32")
-    for h, w in T.grid(224, 224):
-        with T.block("cache"):
-            cache[h, w] = X[h, w]
-    for h, w, kh, kw in T.grid(224, 224, 3, 3):
-        with T.block("compute"):
-            Y[h, w] = T.max(
-                Y[h, w],
-                T.if_then_else(
-                    T.likely(1 <= h + kh, dtype="bool")
-                    and T.likely(h + kh < 225, dtype="bool")
-                    and T.likely(1 <= w + kw, dtype="bool")
-                    and T.likely(w + kw < 225, dtype="bool"),
-                    cache[h + kh - 1, w + kw - 1],
-                    0.0,
-                    dtype="float32",
-                ),
-            )
-
-
[email protected]_func
-def mem_access_in_branch_func(a: T.handle) -> None:
-    A = T.match_buffer(a, (224, 224), "float32")
-    with T.block():
-        B1 = T.alloc_buffer((224, 224), dtype="float32")
-        B2 = T.alloc_buffer((224, 224), dtype="float32")
-        B3 = T.alloc_buffer((224, 224), dtype="float32")
-        B4 = T.alloc_buffer((224, 224), dtype="float32")
-        for i in range(0, 224):
-            for j in range(0, 224):
+    @T.prim_func
+    def expected(a: T.handle) -> None:
+        A = T.match_buffer(a, [224, 224], dtype="float32")
+        with T.block():
+            B1 = T.alloc_buffer([112, 112], dtype="float32")
+            B2 = T.alloc_buffer([224, 224], dtype="float32")
+            B3 = T.alloc_buffer([224, 224], dtype="float32")
+            B4 = T.alloc_buffer([112, 112], dtype="float32")
+            for i, j in T.grid(224, 224):
                 with T.block():
                     if i < 112 and j < 112:
                         B1[i, j] = A[i, j] * 2.0
                     else:
                         B2[i, j] = A[i, j] + 3.0
-        for i in range(0, 224):
-            for j in range(0, 224):
+            for i, j in T.grid(224, 224):
                 with T.block():
                     if i < 112 or j < 112:
                         B3[i, j] = A[i, j] * 2.0
                     else:
-                        B4[i, j] = A[i, j] + 3.0
-
-
[email protected]_func
-def compacted_mem_access_in_branch_func(a: T.handle) -> None:
-    A = T.match_buffer(a, [224, 224], dtype="float32")
-    with T.block():
-        B1 = T.alloc_buffer([112, 112], dtype="float32")
-        B2 = T.alloc_buffer([224, 224], dtype="float32")
-        B3 = T.alloc_buffer([224, 224], dtype="float32")
-        B4 = T.alloc_buffer([112, 112], dtype="float32")
-        for i, j in T.grid(224, 224):
-            with T.block():
-                if i < 112 and j < 112:
-                    B1[i, j] = A[i, j] * 2.0
-                else:
-                    B2[i, j] = A[i, j] + 3.0
-        for i, j in T.grid(224, 224):
-            with T.block():
-                if i < 112 or j < 112:
-                    B3[i, j] = A[i, j] * 2.0
-                else:
-                    B4[i - 112, j - 112] = A[i, j] + 3.0
-
-
[email protected]_func
-def opaque_access_annotated_func(a: T.handle) -> None:
-    A = T.match_buffer(a, (1024,), "float32")
-    with T.block():
-        B = T.alloc_buffer((1024,), dtype="float32")
-        C = T.alloc_buffer((1024,), dtype="float32")
-        for i in range(0, 512):
-            with T.block():
-                # no annotation, opaque access will cover full region
-                T.reads([])
-                T.writes([])
-                T.evaluate(T.call_extern("opaque_extern_function", A.data, 
B.data, dtype="int32"))
-                B[i] = A[i]
-            with T.block():
-                # treat opaque access only access annotated regions, even if
-                # they are not compatible with actual buffer accesses.
-                T.reads([B[i]])
-                T.writes([C[i : i + 9]])
-                T.evaluate(T.call_extern("opaque_extern_function", B.data, 
C.data, dtype="int32"))
-                C[i] = B[i]
-
-
[email protected]_func
-def compacted_opaque_access_annotated_func(a: T.handle) -> None:
-    A = T.match_buffer(a, (1024,), "float32")
-    with T.block():
-        B = T.alloc_buffer((1024,), dtype="float32")
-        C = T.alloc_buffer((520,), dtype="float32")
-        for i in range(0, 512):
-            with T.block():
-                # no annotation, opaque access will cover full region
-                T.reads([])
-                T.writes([])
-                T.evaluate(T.call_extern("opaque_extern_function", A.data, 
B.data, dtype="int32"))
-                B[i] = A[i]
-            with T.block():
-                # treat opaque access only access annotated regions, even if
-                # they are not compatible with actual buffer accesses.
-                T.reads([B[i]])
-                T.writes([C[i : i + 9]])
-                T.evaluate(T.call_extern("opaque_extern_function", B.data, 
C.data, dtype="int32"))
-                C[i] = B[i]
-
-
[email protected]_func
-def sparse_read_cache(
-    A_data: T.Buffer((819,), "float32"),
-    B: T.Buffer((128,), "float32"),
-    A_indptr: T.Buffer((129,), "int32"),
-    A_indices: T.Buffer((819,), "int32"),
-) -> None:
-    for i in T.serial(128):
-        with T.block("rowsum_outer"):
-            T.reads(
-                A_indptr[i : i + 1],
-                A_data[A_indptr[i] + 0 : A_indptr[i] + (A_indptr[i + 1] - 
A_indptr[i])],
-            )
-            T.writes(B[i])
-            with T.block("rowsum_init"):
-                T.reads()
-                T.writes(B[i])
-                B[i] = T.float32(0)
-            for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
-                with T.block():
-                    T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
-                    T.writes(B[i])
-                    A_data_local = T.alloc_buffer([819], dtype="float32", 
scope="local")
-                    with T.block("A_data_cache_read"):
-                        T.reads(A_indptr[i], A_data[A_indptr[i] + k])
-                        T.writes(A_data_local[A_indptr[i] + k])
-                        A_data_local[A_indptr[i] + k] = A_data[A_indptr[i] + k]
-                    with T.block("rowsum_inner"):
-                        T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
-                        T.writes(B[i])
-                        B[i] = B[i] + A_data_local[A_indptr[i] + k]
-
-
[email protected]_func
-def compacted_sparse_read_cache(
-    A_data: T.Buffer((819,), "float32"),
-    B: T.Buffer((128,), "float32"),
-    A_indptr: T.Buffer((129,), "int32"),
-    A_indices: T.Buffer((819,), "int32"),
-) -> None:
-    for i in T.serial(128):
-        with T.block("rowsum_outer"):
-            T.reads(
-                A_indptr[i : i + 1],
-                A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 1] - 
A_indptr[i])],
-            )
-            T.writes(B[i])
-            with T.block("rowsum_init"):
-                T.reads()
-                T.writes(B[i])
-                B[i] = T.float32(0)
-            for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
-                with T.block():
-                    T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
-                    T.writes(B[i])
-                    A_data_local = T.alloc_buffer([1], dtype="float32", 
scope="local")
-                    with T.block("A_data_cache_read"):
-                        T.reads(A_indptr[i], A_data[A_indptr[i] + k])
-                        T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
-                        A_data_local[T.min(A_indptr[i] + k, 0)] = 
A_data[A_indptr[i] + k]
-                    with T.block("rowsum_inner"):
-                        T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
-                        T.writes(B[i])
-                        B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)]
-
-
[email protected]_func
-def narrow_shape(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) 
-> None:
-    B_cache = T.alloc_buffer(10, "float32")
-    for j in T.serial(3):
-        for k in T.serial(4):
-            with T.block("B_cache"):
-                T.where(j * 4 + k < 10)
-                B_cache[j * 4 + k] = B[j]
-    for i in T.serial(10):
-        A[i] = B_cache[i] + T.float32(1)
-
-
[email protected]_func
-def compacted_narrow_shape(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), 
"float32")) -> None:
-    # body
-    # with T.block("root")
-    B_cache = T.alloc_buffer([10], dtype="float32")
-    for j, k in T.grid(3, 4):
-        with T.block("B_cache"):
-            T.where(j * 4 + k < 10)
-            T.reads(B[j])
-            T.writes(B_cache[j * 4 + k])
-            B_cache[j * 4 + k] = B[j]
-    for i in T.serial(10):
-        A[i] = B_cache[i] + T.float32(1)
-
-
-def test_elementwise():
-    _check(elementwise_func, compacted_elementwise_func)
-
-
-def test_unschedulable_block():
-    _check(unschedulable_func, unschedulable_func)  # changes nothing
-
-
-def test_param_access():
-    _check(param_buffer_access_func, param_buffer_access_func)  # changes 
nothing
-
-
-def test_shared_mem():
-    _check(shared_mem_func, compacted_shared_mem_func)
-
-
-def test_warp_mem():
-    _check(warp_mem_func, compacted_warp_mem_func)
-
-
-def test_symbolic():
-    _check(symbolic_func, compacted_symbolic_func)
-
-
-def test_complex():
-    _check(complex_func, compacted_complex_func)
+                        B4[i - 112, j - 112] = A[i, j] + 3.0
 
 
-def test_match_buffer():
-    _check(match_buffer_func, compacted_match_buffer_func)
+class TestAnnotatedOpaqueAccess(BaseCompactTest):
 
+    is_lower_order_free = False
 
-def test_lower_te():
-    x = te.placeholder((1,))
-    y = te.compute((1,), lambda i: x[i] + 2)
-    s = te.create_schedule(y.op)
-    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
-    mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod)
-    tvm.ir.assert_structural_equal(mod, orig_mod)  # CompactBufferAllocation 
should do nothing on TE
-
-
-def test_storage_align():
-    _check(storage_align_func, compacted_storage_align_func)
-
+    @T.prim_func
+    def before(a: T.handle) -> None:
+        A = T.match_buffer(a, (1024,), "float32")
+        with T.block():
+            B = T.alloc_buffer((1024,), dtype="float32")
+            C = T.alloc_buffer((1024,), dtype="float32")
+            for i in range(0, 512):
+                with T.block():
+                    # no annotation, opaque access will cover full region
+                    T.reads([])
+                    T.writes([])
+                    T.evaluate(
+                        T.call_extern("opaque_extern_function", A.data, 
B.data, dtype="int32")
+                    )
+                    B[i] = A[i]
+                with T.block():
+                    # treat opaque access only access annotated regions, even 
if
+                    # they are not compatible with actual buffer accesses.
+                    T.reads([B[i]])
+                    T.writes([C[i : i + 9]])
+                    T.evaluate(
+                        T.call_extern("opaque_extern_function", B.data, 
C.data, dtype="int32")
+                    )
+                    C[i] = B[i]
 
-def test_padding_pattern():
-    _check(padding_pattern_func, compacted_padding_pattern_func)
+    @T.prim_func
+    def expected(a: T.handle) -> None:
+        A = T.match_buffer(a, (1024,), "float32")
+        with T.block():
+            B = T.alloc_buffer((1024,), dtype="float32")
+            C = T.alloc_buffer((520,), dtype="float32")
+            for i in range(0, 512):
+                with T.block():
+                    # no annotation, opaque access will cover full region
+                    T.reads([])
+                    T.writes([])
+                    T.evaluate(
+                        T.call_extern("opaque_extern_function", A.data, 
B.data, dtype="int32")
+                    )
+                    B[i] = A[i]
+                with T.block():
+                    # treat opaque access only access annotated regions, even 
if
+                    # they are not compatible with actual buffer accesses.
+                    T.reads([B[i]])
+                    T.writes([C[i : i + 9]])
+                    T.evaluate(
+                        T.call_extern("opaque_extern_function", B.data, 
C.data, dtype="int32")
+                    )
+                    C[i] = B[i]
 
 
-def test_padding_pattern_inlined():
-    _check(padding_pattern_inlined, compacted_padding_pattern_inlined)
+class TestSparseReadCache(BaseCompactTest):
+    @T.prim_func
+    def before(
+        A_data: T.Buffer((819,), "float32"),
+        B: T.Buffer((128,), "float32"),
+        A_indptr: T.Buffer((129,), "int32"),
+    ) -> None:
+        for i in T.serial(128):
+            with T.block("rowsum_outer"):
+                T.reads(
+                    A_indptr[i : i + 1],
+                    A_data[A_indptr[i] + 0 : A_indptr[i] + (A_indptr[i + 1] - 
A_indptr[i])],
+                )
+                T.writes(B[i])
+                with T.block("rowsum_init"):
+                    T.reads()
+                    T.writes(B[i])
+                    B[i] = T.float32(0)
+                for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
+                    with T.block():
+                        T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
+                        T.writes(B[i])
+                        A_data_local = T.alloc_buffer([819], dtype="float32", 
scope="local")
+                        with T.block("A_data_cache_read"):
+                            T.reads(A_indptr[i], A_data[A_indptr[i] + k])
+                            T.writes(A_data_local[A_indptr[i] + k])
+                            A_data_local[A_indptr[i] + k] = A_data[A_indptr[i] 
+ k]
+                        with T.block("rowsum_inner"):
+                            T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
+                            T.writes(B[i])
+                            B[i] = B[i] + A_data_local[A_indptr[i] + k]
 
+    @T.prim_func
+    def expected(
+        A_data: T.Buffer((819,), "float32"),
+        B: T.Buffer((128,), "float32"),
+        A_indptr: T.Buffer((129,), "int32"),
+    ) -> None:
+        for i in T.serial(128):
+            with T.block("rowsum_outer"):
+                T.reads(
+                    A_indptr[i : i + 1],
+                    A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 
1] - A_indptr[i])],
+                )
+                T.writes(B[i])
+                with T.block("rowsum_init"):
+                    T.reads()
+                    T.writes(B[i])
+                    B[i] = T.float32(0)
+                for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
+                    with T.block():
+                        T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
+                        T.writes(B[i])
+                        A_data_local = T.alloc_buffer([1], dtype="float32", 
scope="local")
+                        with T.block("A_data_cache_read"):
+                            T.reads(A_indptr[i], A_data[A_indptr[i] + k])
+                            T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
+                            A_data_local[T.min(A_indptr[i] + k, 0)] = 
A_data[A_indptr[i] + k]
+                        with T.block("rowsum_inner"):
+                            T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
+                            T.writes(B[i])
+                            B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 
0)]
 
-def test_mem_access_in_branch_func():
-    _check(mem_access_in_branch_func, compacted_mem_access_in_branch_func)
 
+class TestDataDependentRegion(BaseCompactTest):
+    """Partial code of NMS, the `argsort_nms_cpu`'s region depends on inner 
allocated buffer
+    `nkeep`'s value, thus the buffer should not be compacted with data 
dependent region extent."""
 
-def test_opaque_access_annotated_func():
-    _check(opaque_access_annotated_func, 
compacted_opaque_access_annotated_func)
+    @T.prim_func
+    def before(
+        p0: T.Buffer((30,), "float32"),
+        p1: T.Buffer((1,), "int32"),
+        hybrid_nms: T.Buffer((30,), "float32"),
+    ):
+        argsort_nms_cpu = T.decl_buffer([5], "int32", scope="global")
+        for i in range(1):
+            nkeep = T.decl_buffer([1], "int32", scope="global")
+            if 0 < p1[i]:
+                nkeep[0] = p1[i]
+                if 2 < nkeep[0]:
+                    nkeep[0] = 2
+                for j in T.parallel(nkeep[0]):
+                    for k in range(6):
+                        hybrid_nms[i * 30 + j * 6 + k] = p0[
+                            i * 30 + argsort_nms_cpu[i * 5 + j] * 6 + k
+                        ]
+                    hybrid_nms[i * 5 + j] = argsort_nms_cpu[i * 5 + j]
+                if 2 < p1[i]:
+                    for j in T.parallel(p1[i] - nkeep[0]):
+                        for k in range(6):
+                            hybrid_nms[i * 30 + j * 6 + nkeep[0] * 6 + k] = 
T.float32(-1)
+                        hybrid_nms[i * 5 + j + nkeep[0]] = -1
 
+    expected = before
 
-def test_sparse_read_cache():
-    _check(sparse_read_cache, compacted_sparse_read_cache)
 
+class TestNarrowShape(BaseCompactTest):
+    @T.prim_func
+    def before(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) 
-> None:
+        B_cache = T.alloc_buffer(10, "float32")
+        for j in T.serial(3):
+            for k in T.serial(4):
+                with T.block("B_cache"):
+                    T.where(j * 4 + k < 10)
+                    B_cache[j * 4 + k] = B[j]
+        for i in T.serial(10):
+            A[i] = B_cache[i] + T.float32(1)
 
-def test_narrow_shape():
-    _check(narrow_shape, compacted_narrow_shape)
+    @T.prim_func
+    def expected(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")) 
-> None:
+        B_cache = T.alloc_buffer([10], dtype="float32")
+        for j, k in T.grid(3, 4):
+            with T.block("B_cache"):
+                T.where(j * 4 + k < 10)
+                T.reads(B[j])
+                T.writes(B_cache[j * 4 + k])
+                B_cache[j * 4 + k] = B[j]
+        for i in T.serial(10):
+            A[i] = B_cache[i] + T.float32(1)
 
 
-def test_compact_with_let_binding():
+class TestLetBinding(BaseCompactTest):
     @T.prim_func
-    def func_with_let_binding():
+    def before():
         A = T.alloc_buffer((64, 8), "float32")
         B = T.alloc_buffer((64, 8), "float32")
         C = T.alloc_buffer((8, 8), "float32")
@@ -735,10 +758,12 @@ def test_compact_with_let_binding():
                 rjj: T.int32 = riijj % 8
                 C[rii, rjj] += A[rk, rii] * B[rk, rjj]
 
-    _check(func_with_let_binding, func_with_let_binding)
+    expected = before
+
 
+class TestNonIndexLetBinding(BaseCompactTest):
     @T.prim_func
-    def func_with_non_index_let_binding():
+    def before():
         A = T.alloc_buffer((64), "float32")
         x1 = T.call_extern("get", dtype="float16")
         x2 = T.call_extern("get", dtype="float32")
@@ -750,14 +775,12 @@ def test_compact_with_let_binding():
         for rk in range(64):
             A[rk] = T.call_extern("load_ptr", x1, x2, x3, x4, x5, x6, x7, 
dtype="float32")
 
-    _check(func_with_non_index_let_binding, func_with_non_index_let_binding)
+    expected = before
 
 
-def test_compact_spatial_tiled_pad_and_pooling():
+class TestSpatialTiledPadPooling(BaseCompactTest):
     @T.prim_func
-    def spatial_tiled_pad_and_pooling(
-        X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), 
"int32")
-    ) -> None:
+    def before(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), 
"int32")) -> None:
         for h_o, w_o in T.grid(14, 14):
             with T.block():
                 X_cache = T.alloc_buffer([112, 112, 64], dtype="int32")
@@ -795,9 +818,7 @@ def test_compact_spatial_tiled_pad_and_pooling():
                         )
 
     @T.prim_func
-    def compacted_spatial_tiled_pad_and_pooling(
-        X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 56), 
"int32")
-    ) -> None:
+    def expected(X: T.Buffer((64, 112, 112), "int32"), Y: T.Buffer((64, 56, 
56), "int32")) -> None:
         for h_o, w_o in T.grid(14, 14):
             with T.block():
                 T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 
8 + 8])
@@ -846,15 +867,13 @@ def test_compact_spatial_tiled_pad_and_pooling():
                             ),
                         )
 
-    _check(spatial_tiled_pad_and_pooling, 
compacted_spatial_tiled_pad_and_pooling)
 
-
-def test_complex_case_1():
+class TestComplexCase1(BaseCompactTest):
     """Meta-schedule matmul case for compact shared A, B matrix"""
 
     # fmt: off
     @T.prim_func
-    def func(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), 
"float32"), C: T.Buffer((960, 2304), "float32")) -> None:
+    def before(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), 
"float32"), C: T.Buffer((960, 2304), "float32")) -> None:
         for bx in T.thread_binding(144, thread="blockIdx.x"):
             for vx in T.thread_binding(2, thread="vthread.x"):
                 for tx_p in T.thread_binding(256, thread="threadIdx.x"):
@@ -880,7 +899,7 @@ def test_complex_case_1():
                                         C[(((bx // 18 + 0) * 8 + tx_p // 32) * 
8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] 
= C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx 
% 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p 
// 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 
+ k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
 
     @T.prim_func
-    def compacted_func(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 
2304), "float32"), C: T.Buffer((960, 2304), "float32")) -> None:
+    def expected(A: T.Buffer((960, 770), "float32"), B: T.Buffer((770, 2304), 
"float32"), C: T.Buffer((960, 2304), "float32")) -> None:
         for bx in T.thread_binding(144, thread="blockIdx.x"):
             for vx in T.thread_binding(2, thread="vthread.x"):
                 for tx_p in T.thread_binding(256, thread="threadIdx.x"):
@@ -906,14 +925,13 @@ def test_complex_case_1():
                                         C[bx // 18 * 128 + tx_p // 32 * 16 + 
i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 
128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 
+ j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 
+ tx_p % 32 * 2 + j_4]
     # fmt: on
 
-    _check(func, compacted_func)
-
 
-def test_compact_dependent_buffer_indices():
+class TestDependentBufferIndices(BaseCompactTest):
     """Check the upper bound on different indices could be independently 
estimated."""
 
     @T.prim_func
-    def diagonal_access():
+    def before():
+        """This is a diagnal buffer access pattern"""
         for i in range(8):
             with T.block():
                 A = T.alloc_buffer((256, 256), "float32")
@@ -923,7 +941,7 @@ def test_compact_dependent_buffer_indices():
                         A[i * 64 + j * 8 + k, i * 64 + j * 8 + k] = 1.0
 
     @T.prim_func
-    def diagonal_access_compacted() -> None:
+    def expected() -> None:
         for i in T.serial(8):
             with T.block():
                 A = T.alloc_buffer([60, 60], dtype="float32")
@@ -932,14 +950,12 @@ def test_compact_dependent_buffer_indices():
                         T.where(j * 8 + k < 60)
                         A[j * 8 + k, j * 8 + k] = 1.0
 
-    _check(diagonal_access, diagonal_access_compacted)
 
-
-def test_compact_dependent_buffer_indices_of_packed_matmul():
+class TestDependentBufferIndicesOfPackedMatmul(BaseCompactTest):
     """Check the outer dimension of the packed M-dim should be compacted to 1 
wrt split condition."""
 
+    # Workload: non-uniformly packed matmul with a write cache.
     @T.prim_func
-    def nonuniform_packed_matmul_write_cache(
+    def before(
         A: T.Buffer((1020, 64), "float32"),
         B: T.Buffer((1000, 64), "float32"),
         C: T.Buffer((1020, 1000), "float32"),
@@ -976,7 +992,7 @@ def 
test_compact_dependent_buffer_indices_of_packed_matmul():
                         ]
 
+    # The same workload with its write-cache buffer compacted.
     @T.prim_func
-    def nonuniform_packed_matmul_write_cache_compacted(
+    def expected(
         A: T.Buffer((1020, 64), "float32"),
         B: T.Buffer((1000, 64), "float32"),
         C: T.Buffer((1020, 1000), "float32"),
@@ -1006,90 +1022,338 @@ def 
test_compact_dependent_buffer_indices_of_packed_matmul():
                             (ax0 * 16 + ax1) % 255 % 16,
                         ]
 
-    _check(nonuniform_packed_matmul_write_cache, 
nonuniform_packed_matmul_write_cache_compacted)
 
+class TestTileAwareCompaction(BaseCompactTest):
+    """Check that each loop-partitioned tile is compacted independently."""
 
-def test_compact_symbolic_bound0():
-    """Test symbolic bound that get compacted to constant"""
+    # Intentionally not an opaque-block case: ``before`` is lowered
+    # (LowerOpaqueBlock) and loop-partitioned below, prior to compaction.
+    is_lower_order_free = False
+
+    @T.prim_func
+    def before(
+        A: T.Buffer((128, 128), "float32"),
+        B: T.Buffer((128, 128), "float32"),
+        C: T.Buffer((128, 128), "float32"),
+    ):
+        # 5 tiles of 26 cover 128 with a ragged last tile (4 * 26 = 104, 128 - 104 = 24);
+        # the ``< 128`` guards mask the overhang of that last tile.
+        for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}):
+            for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}):
+                A_local = T.decl_buffer((26, 128), scope="local")
+                B_local = T.decl_buffer((128, 26), scope="local")
+                C_local = T.decl_buffer((26, 26), scope="local")
+                for ax0, ax1 in T.grid(26, 128):
+                    if i_0 * 26 + ax0 < 128:
+                        A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1]
+                for ax0, ax1 in T.grid(128, 26):
+                    if j_0 * 26 + ax1 < 128:
+                        B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1]
+                for i_1, j_1, k in T.grid(26, 26, 128):
+                    if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128:
+                        if k == 0:
+                            C_local[i_1, j_1] = T.float32(0)
+                        C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, 
k] * B_local[k, j_1]
+                for ax0, ax1 in T.grid(26, 26):
+                    if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128:
+                        C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1]
+
+    # Lower opaque blocks and partition the hinted loops, so that the interior
+    # and boundary tiles become separate loop nests for compaction.
+    before_mod = tvm.IRModule.from_expr(before)
+    with tvm.transform.PassContext(config={"tir.LoopPartition": 
{"partition_const_loop": True}}):
+        before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod)
+        before_mod = tvm.tir.transform.LoopPartition()(before_mod)
+    before = before_mod["main"]
+
+    @T.prim_func
+    def expected(
+        A: T.Buffer((128, 128), "float32"),
+        B: T.Buffer((128, 128), "float32"),
+        C: T.Buffer((128, 128), "float32"),
+    ):
+        for i_0 in range(4):
+            for j_0 in range(4):
+                # Interior tiles: full 26-wide extents, boundary guards removed.
+                A_local_tile0 = T.decl_buffer((26, 128), scope="local")
+                B_local_tile0 = T.decl_buffer((128, 26), scope="local")
+                C_local_tile0 = T.decl_buffer((26, 26), scope="local")
+                for ax0, ax1 in T.grid(26, 128):
+                    A_local_tile0[ax0, ax1] = A[i_0 * 26 + ax0, ax1]
+                for ax0, ax1 in T.grid(128, 26):
+                    B_local_tile0[ax0, ax1] = B[ax0, j_0 * 26 + ax1]
+                for i_1, j_1, k in T.grid(26, 26, 128):
+                    if k == 0:
+                        C_local_tile0[i_1, j_1] = T.float32(0)
+                    C_local_tile0[i_1, j_1] = (
+                        C_local_tile0[i_1, j_1] + A_local_tile0[i_1, k] * 
B_local_tile0[k, j_1]
+                    )
+                for ax0, ax1 in T.grid(26, 26):
+                    C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local_tile0[ax0, ax1]
+
+            # Last column tile of each row block: only columns 104..127 remain,
+            # so the column extent of B/C shrinks from 26 to 24.
+            A_local_tile1 = T.decl_buffer((26, 128), scope="local")
+            B_local_tile1 = T.decl_buffer((128, 24), scope="local")
+            C_local_tile1 = T.decl_buffer((26, 24), scope="local")
+            for ax0, ax1 in T.grid(26, 128):
+                A_local_tile1[ax0, ax1] = A[i_0 * 26 + ax0, ax1]
+            for ax0, ax1 in T.grid(128, 26):
+                if ax1 < 24:
+                    B_local_tile1[ax0, ax1] = B[ax0, ax1 + 104]
+            for i_1, j_1, k in T.grid(26, 26, 128):
+                if j_1 < 24:
+                    if k == 0:
+                        C_local_tile1[i_1, j_1] = T.float32(0)
+                    C_local_tile1[i_1, j_1] = (
+                        C_local_tile1[i_1, j_1] + A_local_tile1[i_1, k] * 
B_local_tile1[k, j_1]
+                    )
+            for ax0, ax1 in T.grid(26, 26):
+                if ax1 < 24:
+                    C[i_0 * 26 + ax0, ax1 + 104] = C_local_tile1[ax0, ax1]
+
+        for j_0 in range(4):
+            # Last row tile: only rows 104..127 remain, so the row extent of
+            # A/C shrinks from 26 to 24.
+            A_local_tile2 = T.decl_buffer((24, 128), scope="local")
+            B_local_tile2 = T.decl_buffer((128, 26), scope="local")
+            C_local_tile2 = T.decl_buffer((24, 26), scope="local")
+            for ax0, ax1 in T.grid(26, 128):
+                if ax0 < 24:
+                    A_local_tile2[ax0, ax1] = A[ax0 + 104, ax1]
+            for ax0, ax1 in T.grid(128, 26):
+                B_local_tile2[ax0, ax1] = B[ax0, j_0 * 26 + ax1]
+            for i_1, j_1, k in T.grid(26, 26, 128):
+                if i_1 < 24:
+                    if k == 0:
+                        C_local_tile2[i_1, j_1] = T.float32(0)
+                    C_local_tile2[i_1, j_1] = (
+                        C_local_tile2[i_1, j_1] + A_local_tile2[i_1, k] * 
B_local_tile2[k, j_1]
+                    )
+            for ax0, ax1 in T.grid(26, 26):
+                if ax0 < 24:
+                    C[ax0 + 104, j_0 * 26 + ax1] = C_local_tile2[ax0, ax1]
+
+        # Bottom-right corner tile: both extents shrink to 24.
+        A_local_tile3 = T.decl_buffer((24, 128), scope="local")
+        B_local_tile3 = T.decl_buffer((128, 24), scope="local")
+        C_local_tile3 = T.decl_buffer((24, 24), scope="local")
+        for ax0, ax1 in T.grid(26, 128):
+            if ax0 < 24:
+                A_local_tile3[ax0, ax1] = A[ax0 + 104, ax1]
+        for ax0, ax1 in T.grid(128, 26):
+            if ax1 < 24:
+                B_local_tile3[ax0, ax1] = B[ax0, ax1 + 104]
+        for i_1, j_1, k in T.grid(26, 26, 128):
+            if i_1 < 24 and j_1 < 24:
+                if k == 0:
+                    C_local_tile3[i_1, j_1] = T.float32(0)
+                C_local_tile3[i_1, j_1] = (
+                    C_local_tile3[i_1, j_1] + A_local_tile3[i_1, k] * 
B_local_tile3[k, j_1]
+                )
+        for ax0, ax1 in T.grid(26, 26):
+            if ax0 < 24 and ax1 < 24:
+                C[ax0 + 104, ax1 + 104] = C_local_tile3[ax0, ax1]
+
+
+class TestNonStrictCompactionForPaddedMatmul(BaseCompactTest):
+    """Check that non-strict compaction may grow a buffer beyond its original shape."""
+
+    # Non-strict mode: compacted shapes only need to cover the accessed regions
+    # and are not clamped to the original buffer shape.
+    is_strict_mode = False
+
+    @T.prim_func
+    def before(
+        A: T.Buffer((127, 127), "float32"),
+        B: T.Buffer((127, 127), "float32"),
+        C: T.Buffer((127, 127), "float32"),
+    ):
+        """A mock workload where the intermediate buffer allocations are originally
+        too small for the padded 128-wide accesses."""
+        for i_0, j_0 in T.grid(4, 4):
+            with T.block(""):
+                T.reads(A[i_0 * 32 : i_0 * 32 + 32, 0:128], B[0:128, j_0 * 32 
: j_0 * 32 + 32])
+                T.writes(C[i_0 * 32 : i_0 * 32 + 32, j_0 * 32 : j_0 * 32 + 32])
+                A_local = T.alloc_buffer((127, 127), scope="local")
+                B_local = T.alloc_buffer((127, 127), scope="local")
+                C_local = T.alloc_buffer((127, 127), scope="local")
+                for ax0, ax1 in T.grid(32, 128):
+                    with T.block("A_local"):
+                        A_local[i_0 * 32 + ax0, ax1] = T.if_then_else(
+                            i_0 * 32 + ax0 < 127, A[i_0 * 32 + ax0, ax1], 0.0
+                        )
+                for ax0, ax1 in T.grid(128, 32):
+                    with T.block("B_local"):
+                        B_local[ax0, j_0 * 32 + ax1] = T.if_then_else(
+                            j_0 * 32 + ax1 < 127, B[ax0, j_0 * 32 + ax1], 0.0
+                        )
+                for i_1, j_1, k in T.grid(32, 32, 128):
+                    with T.block("compute"):
+                        T.where(i_0 * 32 + i_1 < 127 and j_0 * 32 + j_1 < 127)
+                        if k == 0:
+                            C_local[i_0 * 32 + i_1, j_0 * 32 + j_1] = 
T.float32(0)
+                        C_local[i_0 * 32 + i_1, j_0 * 32 + j_1] = (
+                            C_local[i_0 * 32 + i_1, j_0 * 32 + j_1]
+                            + A_local[i_0 * 32 + i_1, k] * B_local[k, j_0 * 32 
+ j_1]
+                        )
+                for ax0, ax1 in T.grid(32, 32):
+                    with T.block("C_local"):
+                        T.where(i_0 * 32 + ax0 < 127 and j_0 * 32 + ax1 < 127)
+                        C[i_0 * 32 + ax0, j_0 * 32 + ax1] = C_local[i_0 * 32 + 
ax0, j_0 * 32 + ax1]
+
+    @T.prim_func
+    def expected(
+        A: T.Buffer((127, 127), "float32"),
+        B: T.Buffer((127, 127), "float32"),
+        C: T.Buffer((127, 127), "float32"),
+    ):
+        for i_0, j_0 in T.grid(4, 4):
+            with T.block(""):
+                T.reads(A[i_0 * 32 : i_0 * 32 + 32, 0:128], B[0:128, j_0 * 32 
: j_0 * 32 + 32])
+                T.writes(C[i_0 * 32 : i_0 * 32 + 32, j_0 * 32 : j_0 * 32 + 32])
+                # Shapes cover one 32-wide tile; the reduction extent is 128,
+                # which exceeds the original 127 (allowed in non-strict mode).
+                A_local = T.alloc_buffer((32, 128), scope="local")
+                B_local = T.alloc_buffer((128, 32), scope="local")
+                C_local = T.alloc_buffer((32, 32), scope="local")
+                for ax0, ax1 in T.grid(32, 128):
+                    with T.block("A_local"):
+                        A_local[ax0, ax1] = T.if_then_else(
+                            i_0 * 32 + ax0 < 127, A[i_0 * 32 + ax0, ax1], 
T.float32(0)
+                        )
+                for ax0, ax1 in T.grid(128, 32):
+                    with T.block("B_local"):
+                        B_local[ax0, ax1] = T.if_then_else(
+                            j_0 * 32 + ax1 < 127, B[ax0, j_0 * 32 + ax1], 
T.float32(0)
+                        )
+                for i_1, j_1, k in T.grid(32, 32, 128):
+                    with T.block("compute"):
+                        T.where(i_0 * 32 + i_1 < 127 and j_0 * 32 + j_1 < 127)
+                        if k == 0:
+                            C_local[i_1, j_1] = T.float32(0)
+                        C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, 
k] * B_local[k, j_1]
+                for ax0, ax1 in T.grid(32, 32):
+                    with T.block("C_local"):
+                        T.where(i_0 * 32 + ax0 < 127 and j_0 * 32 + ax1 < 127)
+                        C[i_0 * 32 + ax0, j_0 * 32 + ax1] = C_local[ax0, ax1]
+
+
+class TestNotCompactAliasBuffer(BaseCompactTest):
+    """A buffer sharing its storage with another buffer must not be compacted."""
 
-    @tvm.script.ir_module
-    class Before:
-        @T.prim_func
-        def main(x: T.handle, y: T.handle, n: T.int64):
-            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
-            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
-            for i, k_0 in T.grid(T.int64(8), n):
-                with T.block(""):
-                    X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
-                    for ax0 in range(T.int64(32)):
-                        with T.block("X_global"):
-                            X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 * 
T.int64(32) + ax0]
-                    for k_1 in range(T.int64(32)):
-                        with T.block("Y"):
-                            Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 * 
T.int64(32) + k_1]
-
-    @tvm.script.ir_module
-    class Expected:
-        @T.prim_func
-        def main(x: T.handle, y: T.handle, n: T.int64):
-            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
-            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
-            for i, k_0 in T.grid(T.int64(8), n):
-                with T.block(""):
-                    X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
-                    for ax0 in range(T.int64(32)):
-                        with T.block("X_global"):
-                            X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32) 
+ ax0]
-                    for k_1 in range(T.int64(32)):
-                        with T.block("Y"):
-                            Y[i, k_0 * T.int64(32) + k_1] = 
X_global[T.int64(0), k_1]
-
-    mod = Before
-    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
-    after = tvm.tir.transform.Simplify()(mod)
-    tvm.ir.assert_structural_equal(after, Expected)
-
-
-def test_compact_symbolic_bound1():
+    # Not a block-form (opaque block) test case.
+    is_lower_order_free = False
+
+    @T.prim_func
+    def before():
+        """Partially accessed buffer, but should not compact
+        because existence of aliasing buffer B."""
+        # A (int8) and B (float16) are both declared over the same ``data`` allocation.
+        data = T.allocate([1024], "int8")
+        A = T.decl_buffer([1024], "int8", data)
+        B = T.decl_buffer([512], "float16", data)
+        for i in range(10):
+            A[i] = A[i] + T.int8(1)
+        for i in range(10):
+            B[i] = B[i] + T.float16(1)
+
+    # No compaction is expected: the function must be left unchanged.
+    expected = before
+
+
+class TestNotCompactBufferWithDifferentDtype(BaseCompactTest):
+    """A buffer whose dtype differs from its allocation's dtype must not be compacted."""
+
+    # Not a block-form (opaque block) test case.
+    is_lower_order_free = False
+
+    @T.prim_func
+    def before():
+        """Partially accessed buffer, but should not compact because the
+        declared buffer dtype (int32) differs from the allocation dtype (int8)."""
+        data = T.allocate([1024], "int8")
+        A = T.decl_buffer([256], "int32", data)
+        for i in range(10):
+            A[i] = A[i] + 1
+
+    # No compaction is expected: the function must be left unchanged.
+    expected = before
+
+
+class TestNonBoolCondition(BaseCompactTest):
+    """Compaction must handle an integer (non-bool) ``if`` condition on the loop var."""
+
+    # Not a block-form (opaque block) test case.
+    is_lower_order_free = False
+
+    @T.prim_func
+    def before():
+        data = T.allocate([12], "int32")
+        A = T.Buffer([12], "int32", data)
+        for i in range(10):
+            if i:
+                A[i] = A[i] + 1
+
+    @T.prim_func
+    def expected():
+        # ``if i`` restricts accesses to i in 1..9, so only 9 elements are kept
+        # and every index is shifted down by 1.
+        data = T.allocate([9], "int32")
+        A = T.Buffer([9], "int32", data)
+        for i in range(10):
+            if i:
+                A[i - 1] = A[i - 1] + 1
+
+
+def test_lower_te():
+    """CompactBufferAllocation must leave a module built from a TE schedule unchanged."""
+    x = te.placeholder((1,))
+    y = te.compute((1,), lambda i: x[i] + 2)
+    s = te.create_schedule(y.op)
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+    mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod)
+    # CompactBufferAllocation should do nothing on TE
+    tvm.ir.assert_structural_equal(mod, orig_mod)
+
+
+class TestCompactSymbolicBound0(BaseCompactTest):
     """Test symbolic bound that get compacted to constant"""
 
+    # This class defines no test method of its own; the check presumably comes
+    # from BaseCompactTest, so without that base pytest would run nothing here.
+    # Only compaction (followed by simplification) is checked, matching the
+    # coverage this case had before.
+    # TODO(review): drop this override if the lowering-order check holds here.
+    is_lower_order_free = False
+
-    @tvm.script.ir_module
-    class Before:
-        @T.prim_func
-        def main(x: T.handle, y: T.handle, n: T.int64):
-            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
-            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
-            for i, k_0 in T.grid(T.int64(8), n):
-                with T.block(""):
-                    X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+    @T.prim_func
+    def before(x: T.handle, y: T.handle, n: T.int64):
+        X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+        Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+        for i, k_0 in T.grid(T.int64(8), n):
+            with T.block(""):
+                X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+                for ax0 in range(T.int64(32)):
                     with T.block("X_global"):
-                        for x0 in range(T.int64(32)):
-                            X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 * 
T.int64(32) + x0]
+                        X_global[i, k_0 * T.int64(32) + ax0] = X[i, k_0 * 
T.int64(32) + ax0]
+                for k_1 in range(T.int64(32)):
                     with T.block("Y"):
-                        for x1 in range(T.int64(32)):
-                            Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 * 
T.int64(32) + x1]
-
-    @tvm.script.ir_module
-    class Expected:
-        @T.prim_func
-        def main(x: T.handle, y: T.handle, n: T.int64):
-            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
-            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
-            # with T.block("root"):
-            for i, k_0 in T.grid(T.int64(8), n):
-                with T.block(""):
-                    X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+                        Y[i, k_0 * T.int64(32) + k_1] = X_global[i, k_0 * 
T.int64(32) + k_1]
+
+    @T.prim_func
+    def expected(x: T.handle, y: T.handle, n: T.int64):
+        X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+        Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+        for i, k_0 in T.grid(T.int64(8), n):
+            with T.block(""):
+                X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+                for ax0 in range(T.int64(32)):
                     with T.block("X_global"):
-                        for x0 in range(T.int64(32)):
-                            X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32) 
+ x0]
+                        X_global[T.int64(0), ax0] = X[i, k_0 * T.int64(32) + 
ax0]
+                for k_1 in range(T.int64(32)):
                     with T.block("Y"):
-                        for x1 in range(T.int64(32)):
-                            Y[i, k_0 * T.int64(32) + x1] = 
X_global[T.int64(0), x1]
+                        Y[i, k_0 * T.int64(32) + k_1] = X_global[T.int64(0), 
k_1]
 
-    mod = Before
-    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
-    after = tvm.tir.transform.Simplify()(mod)
-    tvm.ir.assert_structural_equal(after, Expected)
+
+class TestCompactSymbolicBound1(BaseCompactTest):
+    """Test a symbolic bound that gets compacted to a constant, with the
+    loops nested inside the blocks."""
+
+    # This class defines no test method of its own; the check presumably comes
+    # from BaseCompactTest, so without that base pytest would run nothing here.
+    # Only compaction (followed by simplification) is checked, matching the
+    # coverage this case had before.
+    # TODO(review): drop this override if the lowering-order check holds here.
+    is_lower_order_free = False
+
+    @T.prim_func
+    def before(x: T.handle, y: T.handle, n: T.int64):
+        X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+        Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+        for i, k_0 in T.grid(T.int64(8), n):
+            with T.block(""):
+                X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))
+                with T.block("X_global"):
+                    for x0 in range(T.int64(32)):
+                        X_global[i, k_0 * T.int64(32) + x0] = X[i, k_0 * T.int64(32) + x0]
+                with T.block("Y"):
+                    for x1 in range(T.int64(32)):
+                        Y[i, k_0 * T.int64(32) + x1] = X_global[i, k_0 * T.int64(32) + x1]
+
+    @T.prim_func
+    def expected(x: T.handle, y: T.handle, n: T.int64):
+        X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
+        Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
+        for i, k_0 in T.grid(T.int64(8), n):
+            with T.block(""):
+                # Each (i, k_0) iteration touches a single row of 32 elements.
+                X_global = T.alloc_buffer((T.int64(1), T.int64(32)))
+                with T.block("X_global"):
+                    for x0 in range(T.int64(32)):
+                        X_global[T.int64(0), x0] = X[i, k_0 * T.int64(32) + x0]
+                with T.block("Y"):
+                    for x1 in range(T.int64(32)):
+                        Y[i, k_0 * T.int64(32) + x1] = X_global[T.int64(0), x1]
 
 
 if __name__ == "__main__":


Reply via email to