This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5711c35ae0 [TIR Pass] decouple flatten buffer to lower opaque block 
pass and flatten buffer. (#12172)
5711c35ae0 is described below

commit 5711c35ae01b1dec3527726149efc6fe1a4bc7c6
Author: Fred.Jia <[email protected]>
AuthorDate: Wed Jul 27 14:22:18 2022 +0800

    [TIR Pass] decouple flatten buffer to lower opaque block pass and flatten 
buffer. (#12172)
---
 include/tvm/tir/transform.h                        |  11 +-
 python/tvm/script/tir/scope_handler.py             |  10 +-
 python/tvm/tir/transform/transform.py              |  16 +-
 src/driver/driver_api.cc                           |   1 +
 src/meta_schedule/postproc/verify_gpu_code.cc      |   1 +
 src/tir/transforms/flatten_buffer.cc               | 135 ++---------
 .../{flatten_buffer.cc => lower_opaque_block.cc}   | 139 ++---------
 tests/python/unittest/test_tir_buffer.py           |   1 +
 .../unittest/test_tir_transform_flatten_buffer.py  | 261 +++++----------------
 .../test_tir_transform_inject_ptx_async_copy.py    |   2 +
 .../unittest/test_tir_transform_loop_partition.py  |   1 +
 ...py => test_tir_transform_lower_opaque_block.py} | 131 ++++-------
 12 files changed, 177 insertions(+), 532 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 005bf84103..c758a00b3f 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -457,9 +457,14 @@ TVM_DLL Pass LegalizePackedCalls();
 TVM_DLL Pass LowerMatchBuffer();
 
 /*!
- * \brief Flatten the multi-dimensional BufferLoad and BufferStore
- *        to single dimensional Load/Store. Also remove Block to
- *        ensure that the flattened TIR can not be scheduled again.
+ * \brief Remove blocks (lowering opaque blocks to plain statements) so that the TIR cannot be scheduled again.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerOpaqueBlock();
+
+/*!
+ * \brief Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional
+ *        BufferLoad/BufferStore, for TIR that does not contain opaque blocks.
  * \return The pass.
  */
 TVM_DLL Pass FlattenBuffer();
diff --git a/python/tvm/script/tir/scope_handler.py 
b/python/tvm/script/tir/scope_handler.py
index 76fbf26eea..92aaf8b4d9 100644
--- a/python/tvm/script/tir/scope_handler.py
+++ b/python/tvm/script/tir/scope_handler.py
@@ -111,16 +111,10 @@ class Allocate(WithScopeHandler):
             condition = tvm.runtime.convert(condition)
             scope = tvm.runtime.convert(scope)
 
-            # Currently, allocate nodes should only occur after buffer
-            # flattening has been applied.  This can be simplified in
-            # the future by having the AllocateNode hold a buffer
-            # object directly.
-            flattened = self.buffer.get_flattened_buffer()
-
             return tvm.tir.Allocate(
                 self.buffer.data,
-                flattened.dtype,
-                flattened.shape,
+                self.buffer.dtype,
+                self.buffer.shape,
                 condition,
                 self.body,
                 annotations=annotations,
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 2a4ff6618a..6cc7b2e1f8 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -769,10 +769,20 @@ def LowerMatchBuffer():
     return _ffi_api.LowerMatchBuffer()  # type: ignore
 
 
+def LowerOpaqueBlock():
+    """Remove blocks (lowering opaque blocks to plain statements) so that the TIR cannot be scheduled again.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerOpaqueBlock()  # type: ignore
+
+
 def FlattenBuffer():
-    """Flatten the multi-dimensional BufferLoad and BufferStore
-    to single dimensional Load/Store. Also remove Block to
-    ensure that the flattened TIR can not be scheduled again.
+    """Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional
+    BufferLoad/BufferStore, for TIR that does not contain opaque blocks.
 
     Returns
     -------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 0446347eca..6f4fb618d3 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -202,6 +202,7 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition) {
   pass_list.push_back(tir::transform::CompactBufferAllocation());
   pass_list.push_back(tir::transform::LowerMatchBuffer());
   pass_list.push_back(tir::transform::InjectSoftwarePipeline());
+  pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
   pass_list.push_back(tir::transform::LowerVtcmAlloc());
   pass_list.push_back(tir::transform::BF16Legalize());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc 
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 857b732c98..dfe2c5a06a 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
           pass_list.push_back(tir::transform::CompactBufferAllocation());
           pass_list.push_back(tir::transform::LowerMatchBuffer());
           pass_list.push_back(tir::transform::InjectSoftwarePipeline());
+          pass_list.push_back(tir::transform::LowerOpaqueBlock());
           pass_list.push_back(tir::transform::FlattenBuffer());
           pass_list.push_back(tir::transform::BF16Legalize());
           pass_list.push_back(tir::transform::NarrowDataType(32));
diff --git a/src/tir/transforms/flatten_buffer.cc 
b/src/tir/transforms/flatten_buffer.cc
index 21de191db0..dcc23a72b2 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -21,32 +21,17 @@
  * \file flatten_buffer.cc
  */
 
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/function.h>
-#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
-#include "../../support/utils.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
-PrimExpr BufferArea(const Buffer& buffer) {
-  if (buffer->strides.size()) {
-    ICHECK(buffer->shape.size() == buffer->strides.size());
-    return buffer->strides[0] * buffer->shape[0];
-  }
-  PrimExpr area = Integer(1);
-  for (const PrimExpr& dim : buffer->shape) {
-    area = area * dim;
-  }
-  return area;
-}
-
 /*!
  * \brief Transform multi-dimension BufferLoad/BufferStore into 
device-supported dimension
+ *        for TIR that does not contain opaque blocks.
  */
 class BufferFlattener : public StmtExprMutator {
  public:
@@ -68,76 +53,25 @@ class BufferFlattener : public StmtExprMutator {
     }
   }
 
-  Stmt VisitStmt_(const BlockRealizeNode* op) final {
-    // We have convert blocks into opaque blocks in previous passes.
-    ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in 
FlattenBuffer. Please "
-                                       "call pass ConvertBlocksToOpaque 
before.";
-    // Step 1. Visit the body
-    Block new_block = Downcast<Block>(this->VisitStmt(op->block));
-    PrimExpr predicate = this->VisitExpr(op->predicate);
-    // Step 2. Transform the `predicate` to if-then-else
-    Stmt body = new_block->body;
-    if (!is_one(predicate)) {
-      body = IfThenElse(predicate, std::move(body));
-    }
-    // Step 3. Handle allocations in reverse order
-    for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
-      Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]);
-      body = Allocate(buffer->data, buffer->dtype, buffer->shape, 
const_true(), std::move(body));
-    }
-    return body;
-  }
-
-  Stmt VisitStmt_(const ForNode* op) final {
-    // Step 1. Update unit loop info.
-    PrimExpr min = this->VisitExpr(op->min);
-    PrimExpr extent = this->VisitExpr(op->extent);
-    if (is_one(extent) && op->annotations.empty()) {
-      // handling unit loop
-      unit_loop_vars_[op->loop_var] = min;
-    }
-    // Step 2. Visit recursively
-    Stmt body = this->VisitStmt(op->body);
-    // Step 3. Create new For loop accordingly
-    if (op->kind == ForKind::kThreadBinding) {
-      // Case 1. Thread binding
-      ICHECK(op->thread_binding.defined());
-      String thread_tag = op->thread_binding.value()->thread_tag;
-      body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
-    } else if (is_one(extent) && op->annotations.empty()) {
-      // Case 2. Unit loop
-      return body;
-    } else {
-      // Case 3. An ordinary loop
-      body = For(op->loop_var, std::move(min), std::move(extent), op->kind, 
std::move(body));
-    }
-    // Step 4. Handle annotations
-    std::set<std::string> ordered_ann_keys;
-    for (const auto& annotation : op->annotations) {
-      ordered_ann_keys.insert(annotation.first);
-    }
-    for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); 
++it) {
-      const std::string& ann_key = *it;
-      const ObjectRef& ann_value = op->annotations.at(ann_key);
-      if (attr::IsPragmaKey(ann_key)) {
-        body =
-            AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, 
ann_value), std::move(body));
-      }
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+    // TODO(Lunderberg): Move the handling of boolean into a
+    // dedicated pass.
+    if (alloc->dtype == DataType::Bool()) {
+      auto writer = alloc.CopyOnWrite();
+      writer->dtype = DataType::Int(8);
     }
-    return body;
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    Var var = GetRef<Var>(op);
-    auto it = unit_loop_vars_.find(var);
-    if (it == unit_loop_vars_.end()) {
-      return std::move(var);
+    // Handle multi-dimension allocations
+    if (alloc->extents.size() == 1) {
+      return std::move(alloc);
     } else {
-      PrimExpr expr = it->second;
-      if (expr.dtype() != var.dtype()) {
-        expr = tvm::cast(var.dtype(), std::move(expr));
+      Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
+      for (size_t i = 0; i < alloc->extents.size(); i++) {
+        flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
       }
-      return expr;
+      auto n = alloc.CopyOnWrite();
+      n->extents = flat_extent;
+      return std::move(alloc);
     }
   }
 
@@ -146,7 +80,6 @@ class BufferFlattener : public StmtExprMutator {
     if (it != buffer_remap_.end()) {
       return it->second;
     }
-
     auto flattened = buf.GetFlattenedBuffer();
 
     // TODO(Lunderberg): Move the handling of boolean into a
@@ -208,40 +141,6 @@ class BufferFlattener : public StmtExprMutator {
     return node;
   }
 
-  static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String 
thread_tag,
-                               Stmt body) {
-    IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
-                     /*var=*/std::move(var),
-                     /*iter_type=*/IterVarType::kThreadIndex,
-                     /*thread_tag=*/thread_tag);
-    String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" ||
-                       thread_tag == "vthread.y" || thread_tag == "vthread.z")
-                          ? attr::virtual_thread
-                          : attr::thread_extent;
-    return AttrStmt(/*node=*/std::move(iter_var),
-                    /*attr_key=*/std::move(attr_key),
-                    /*value=*/std::move(extent),
-                    /*body=*/std::move(body));
-  }
-
-  /*! \brief Convert attr value from annotation map into PrimExpr. */
-  PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
-    if (!obj.defined()) {
-      return PrimExpr();
-    } else if (const PrimExprNode* expr = obj.as<PrimExprNode>()) {
-      return GetRef<PrimExpr>(expr);
-    } else if (const StringObj* str = obj.as<StringObj>()) {
-      return std::move(StringImm(str->data));
-    } else {
-      LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << 
obj->GetTypeKey()
-                 << " not supported";
-      return PrimExpr();
-    }
-  }
-
-  /*! \brief Record the loop_var and loop start value of unit loops, whose 
extent is one. */
-  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
unit_loop_vars_;
-
   /*! \brief Map of buffers being remapped. */
   std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> 
buffer_remap_;
 
diff --git a/src/tir/transforms/flatten_buffer.cc 
b/src/tir/transforms/lower_opaque_block.cc
similarity index 56%
copy from src/tir/transforms/flatten_buffer.cc
copy to src/tir/transforms/lower_opaque_block.cc
index 21de191db0..69d8787aa1 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -18,56 +18,22 @@
  */
 
 /*!
- * \file flatten_buffer.cc
+ * \file lower_opaque_block.cc
  */
 
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/function.h>
-#include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
-#include "../../support/utils.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
-PrimExpr BufferArea(const Buffer& buffer) {
-  if (buffer->strides.size()) {
-    ICHECK(buffer->shape.size() == buffer->strides.size());
-    return buffer->strides[0] * buffer->shape[0];
-  }
-  PrimExpr area = Integer(1);
-  for (const PrimExpr& dim : buffer->shape) {
-    area = area * dim;
-  }
-  return area;
-}
-
 /*!
- * \brief Transform multi-dimension BufferLoad/BufferStore into 
device-supported dimension
+ * \brief Remove blocks (lowering opaque blocks to plain statements) so that the TIR cannot be scheduled again.
  */
-class BufferFlattener : public StmtExprMutator {
- public:
-  static PrimFunc Flatten(PrimFunc func) {
-    Map<Var, Buffer> preflattened_buffer_map =
-        Merge(func->buffer_map, func->preflattened_buffer_map);
-    auto pass = BufferFlattener(func->buffer_map);
-    auto writer = func.CopyOnWrite();
-    writer->body = pass.VisitStmt(func->body);
-    writer->preflattened_buffer_map = preflattened_buffer_map;
-    writer->buffer_map = pass.updated_extern_buffer_map_;
-    return func;
-  }
-
+class OpaqueBlockLower : public StmtExprMutator {
  private:
-  explicit BufferFlattener(const Map<Var, Buffer>& extern_buffer_map) {
-    for (const auto& kv : extern_buffer_map) {
-      updated_extern_buffer_map_.Set(kv.first, GetFlattenedBuffer(kv.second));
-    }
-  }
-
   Stmt VisitStmt_(const BlockRealizeNode* op) final {
     // We have convert blocks into opaque blocks in previous passes.
     ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in 
FlattenBuffer. Please "
@@ -82,8 +48,16 @@ class BufferFlattener : public StmtExprMutator {
     }
     // Step 3. Handle allocations in reverse order
     for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
-      Buffer buffer = GetFlattenedBuffer(new_block->alloc_buffers[i - 1]);
-      body = Allocate(buffer->data, buffer->dtype, buffer->shape, 
const_true(), std::move(body));
+      const Buffer& buffer = new_block->alloc_buffers[i - 1];
+      Array<PrimExpr> new_shape = buffer->shape;
+      if (buffer->strides.size()) {
+        ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
+        for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
+          ICHECK(is_zero(floormod(buffer->strides[i - 1], 
buffer->strides[i])));
+          new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
+        }
+      }
+      body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), 
std::move(body));
     }
     return body;
   }
@@ -141,73 +115,6 @@ class BufferFlattener : public StmtExprMutator {
     }
   }
 
-  Buffer GetFlattenedBuffer(Buffer buf) {
-    auto it = buffer_remap_.find(buf);
-    if (it != buffer_remap_.end()) {
-      return it->second;
-    }
-
-    auto flattened = buf.GetFlattenedBuffer();
-
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (flattened->dtype == DataType::Bool()) {
-      auto writer = flattened.CopyOnWrite();
-      writer->dtype = DataType::Int(8);
-    }
-
-    buffer_remap_[buf] = flattened;
-    return flattened;
-  }
-
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
-    bool store_returns_bool = (op->value.dtype() == DataType::Bool());
-    store = VisitBufferAccess(store);
-
-    // Handle casts from the value's dtype to the dtype of the
-    // backing array.
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (store_returns_bool) {
-      ICHECK_EQ(store->buffer->dtype, DataType::Int(8))
-          << "Expected int8 backing array for boolean tensor";
-      auto writer = store.CopyOnWrite();
-      writer->value = tvm::cast(DataType::Int(8), store->value);
-      return store;
-    }
-    return store;
-  }
-
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    bool load_returns_bool = (op->dtype == DataType::Bool());
-    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    load = VisitBufferAccess(load);
-    // Handle casts from dtype of the backing array to value's dtype.
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (load_returns_bool) {
-      ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
-          << "Expected int8 backing array for boolean tensor";
-      load.CopyOnWrite()->dtype = DataType::Int(8);
-      return tvm::cast(DataType::Bool(), load);
-    } else {
-      return std::move(load);
-    }
-  }
-
-  template <typename Node>
-  Node VisitBufferAccess(Node node) {
-    ICHECK(node->buffer.defined());
-    auto flattened_indices = node->buffer->ElemOffset(node->indices);
-    Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);
-
-    auto writer = node.CopyOnWrite();
-    writer->buffer = flattened_buffer;
-    writer->indices = flattened_indices;
-    return node;
-  }
-
   static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String 
thread_tag,
                                Stmt body) {
     IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
@@ -241,18 +148,14 @@ class BufferFlattener : public StmtExprMutator {
 
   /*! \brief Record the loop_var and loop start value of unit loops, whose 
extent is one. */
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
unit_loop_vars_;
-
-  /*! \brief Map of buffers being remapped. */
-  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> 
buffer_remap_;
-
-  /*! \brief The updated external buffer map. */
-  Map<Var, Buffer> updated_extern_buffer_map_;
 };
 
-PrimFunc FlattenBuffer(PrimFunc f) {
+PrimFunc LowerOpaqueBlock(PrimFunc f) {
   // Only apply this pass to TIR that is not from TE schedules
   if (!IsFromLegacyTESchedule(f)) {
-    return BufferFlattener::Flatten(f);
+    auto fptr = f.CopyOnWrite();
+    fptr->body = OpaqueBlockLower()(std::move(fptr->body));
+    return f;
   } else {
     return f;
   }
@@ -260,14 +163,14 @@ PrimFunc FlattenBuffer(PrimFunc f) {
 
 namespace transform {
 
-Pass FlattenBuffer() {
+Pass LowerOpaqueBlock() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return FlattenBuffer(std::move(f));
+    return LowerOpaqueBlock(std::move(f));
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {});
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerOpaqueBlock", {});
 }
 
-TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer);
+TVM_REGISTER_GLOBAL("tir.transform.LowerOpaqueBlock").set_body_typed(LowerOpaqueBlock);
 }  // namespace transform
 
 }  // namespace tir
diff --git a/tests/python/unittest/test_tir_buffer.py 
b/tests/python/unittest/test_tir_buffer.py
index 10e827978c..d250fada6a 100644
--- a/tests/python/unittest/test_tir_buffer.py
+++ b/tests/python/unittest/test_tir_buffer.py
@@ -115,6 +115,7 @@ def test_buffer_vload_nullptr():
             [
                 tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
                 tvm.tir.transform.CompactBufferAllocation(),
+                tvm.tir.transform.LowerOpaqueBlock(),
                 tvm.tir.transform.FlattenBuffer(),
             ]
         )(mod)
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py 
b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index f1a33a4fb2..ea9c604e71 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -15,7 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te, tir
+import tvm.testing
+from tvm import te
 from tvm.script import tir as T
 
 
@@ -28,24 +29,15 @@ def _check(original, transformed):
 
 
 @T.prim_func
-def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
+def elementwise_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
-    for i in range(0, 16):
-        with T.block():
-            T.reads(A[i, 0:16])
-            T.writes(C[i, 0:16])
-            B = T.alloc_buffer([1, 16], "float32", scope="global")
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[0, j])
-                    B[0, j] = A[i, j] + 1.0
-            for j in range(0, 16):
-                with T.block():
-                    T.reads(B[0, j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[0, j] * 2.0
+    for i in T.serial(0, 16):
+        B_new = T.allocate([1, 16], "float32", "global")
+        for j in T.serial(0, 16):
+            B_new[0, j] = A[i, j] + 1.0
+        for j in T.serial(0, 16):
+            C[i, j] = B_new[0, j] * 2.0
 
 
 @T.prim_func
@@ -63,26 +55,22 @@ def flattened_elementwise_func(a: T.handle, c: T.handle) -> 
None:
 
 
 @T.prim_func
-def compacted_gpu_func(a: T.handle, c: T.handle) -> None:
+def gpu_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.thread_binding(0, 4, thread="blockIdx.x"):
-        for i1 in T.thread_binding(0, 2, thread="threadIdx.x"):
-            for i2 in T.thread_binding(0, 2, thread="vthread"):
-                with T.block():
-                    T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16])
-                    T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16])
-                    B = T.alloc_buffer([1, 16], "float32", scope="local")
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(A[i0 * 4 + i1 * 2 + i2, j])
-                            T.writes(B[0, j])
-                            B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
-                    for j in range(0, 16):
-                        with T.block():
-                            T.reads(B[0, j])
-                            T.writes(C[i0 * 4 + i1 * 2 + i2, j])
-                            C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
+
+    i0 = T.env_thread("blockIdx.x")
+    i1 = T.env_thread("threadIdx.x")
+    i2 = T.env_thread("vthread")
+
+    T.launch_thread(i0, 4)
+    T.launch_thread(i1, 2)
+    T.launch_thread(i2, 2)
+    B = T.allocate([1, 16], "float32", "local")
+    for j in range(0, 16):
+        B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
+    for j in range(0, 16):
+        C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
 
 
 @T.prim_func
@@ -107,25 +95,16 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
 
 
 @T.prim_func
-def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) 
-> None:
+def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
     A = T.match_buffer(a, (n, m), "float32")
     C = T.match_buffer(c, (n, m), "float32")
 
     for i in range(0, n):
-        with T.block():
-            T.reads(A[i, m])
-            T.writes(C[i, m])
-            B = T.alloc_buffer((m,), "float32", scope="global")
-            for j in range(0, m):
-                with T.block():
-                    T.reads(A[i, j])
-                    T.writes(B[j])
-                    B[j] = A[i, j] + 1.0
-            for j in range(0, m):
-                with T.block():
-                    T.reads(B[j])
-                    T.writes(C[i, j])
-                    C[i, j] = B[j] * 2.0
+        B = T.allocate([m], "float32", "global")
+        for j in range(0, m):
+            B[j] = A[i, j] + 1.0
+        for j in range(0, m):
+            C[i, j] = B[j] * 2.0
 
 
 @T.prim_func
@@ -144,105 +123,44 @@ def flattened_symbolic_func(a: T.handle, c: T.handle, n: 
T.int32, m: T.int32) ->
 
 
 @T.prim_func
-def compacted_predicate_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (32), "float32")
-    C = T.match_buffer(c, (32), "float32")
-
-    for i, j in T.grid(5, 7):
-        with T.block():
-            T.reads(A[i * 7 + j])
-            T.writes(C[i * 7 + j])
-            T.where(i * 7 + j < 32)
-            C[i * 7 + j] = A[i * 7 + j] + 1.0
-
-
[email protected]_func
-def flattened_predicate_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (32), "float32")
-    C = T.match_buffer(c, (32), "float32")
-    T.preflattened_buffer(A, (32), "float32", data=A.data)
-    T.preflattened_buffer(C, (32), "float32", data=C.data)
-
-    for i, j in T.grid(5, 7):
-        if i * 7 + j < 32:
-            C[i * 7 + j] = A[i * 7 + j] + 1.0
-
-
[email protected]_func
-def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (32), "float32")
-    C = T.match_buffer(c, (32), "float32")
-
-    for x, y, z in T.grid(4, 1, 8):
-        with T.block():
-            T.reads(A[x * 8 + y * 8 + z])
-            T.writes(C[x * 8 + y * 8 + z])
-            C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0
-
-
[email protected]_func
-def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (32), "float32")
-    C = T.match_buffer(c, (32), "float32")
-    T.preflattened_buffer(A, (32), "float32", data=A.data)
-    T.preflattened_buffer(C, (32), "float32", data=C.data)
-
-    for x, z in T.grid(4, 8):
-        C[x * 8 + z] = A[x * 8 + z] + 1.0
-
-
[email protected]_func
-def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, (32), "float32")
-    D = T.match_buffer(d, (32), "float32")
+def multi_alloc_func(a: T.handle, d: T.handle) -> None:
+    A = T.match_buffer(a, (4, 32), "float32")
+    D = T.match_buffer(d, (4, 32), "float32")
 
-    for i in range(0, 32):
-        with T.block():
-            T.reads(A[i])
-            T.writes(D[i])
-            B = T.alloc_buffer((32,), scope="global")
-            C = T.alloc_buffer((32,), scope="global")
-            B[i] = A[i] + 1.0
-            C[i] = A[i] + B[i]
-            D[i] = C[i] * 2.0
+    for i, j in T.grid(4, 32):
+        B = T.allocate((4, 32), "float32", scope="global")
+        C = T.allocate((4, 32), "float32", scope="global")
+        B[i, j] = A[i, j] + 1.0
+        C[i, j] = A[i, j] + B[i, j]
+        D[i, j] = C[i, j] * 2.0
 
 
 @T.prim_func
 def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, (32), "float32")
-    D = T.match_buffer(d, (32), "float32")
-    T.preflattened_buffer(A, (32), "float32", data=A.data)
-    T.preflattened_buffer(D, (32), "float32", data=D.data)
+    A = T.match_buffer(a, (128), "float32")
+    D = T.match_buffer(d, (128), "float32")
+    T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
+    T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
 
-    for i in range(0, 32):
-        B = T.allocate((32,), "float32", "global")
-        C = T.allocate((32,), "float32", "global")
-        B[i] = A[i] + 1.0
-        C[i] = A[i] + B[i]
-        D[i] = C[i] * 2.0
+    for i, j in T.grid(4, 32):
+        B = T.allocate((128), "float32", "global")
+        C = T.allocate((128), "float32", "global")
+        B[i * 32 + j] = A[i * 32 + j] + 1.0
+        C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
+        D[i * 32 + j] = C[i * 32 + j] * 2.0
 
 
 @T.prim_func
-def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
+def strided_buffer_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in range(0, 4):
-        with T.block():
-            T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16])
-            T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16])
-            B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], 
scope="global")
-            for i1 in range(0, 4):
-                for j in range(0, 16):
-                    with T.block():
-                        T.reads(A[i0 * 4 + i1, j])
-                        T.writes(B[i1, j])
-                        B[i1, j] = A[i0 * 4 + i1, j] + 1.0
-            for i1 in range(0, 4):
-                for j in range(0, 16):
-                    with T.block():
-                        T.reads(B[i1, j])
-                        T.writes(C[i0 * 4 + i1, j])
-                        C[i0 * 4 + i1, j] = B[i1, j] * 2.0
+    for i0 in T.serial(4):
+        B = T.allocate([4, 17], "float32", "global")
+        B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, 
strides=[17, 1])
+        for i1, j in T.grid(4, 16):
+            B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
+        for i1, j in T.grid(4, 16):
+            C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
 
 
 @T.prim_func
@@ -261,20 +179,10 @@ def flattened_strided_buffer_func(a: T.handle, c: 
T.handle) -> None:
                 C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
 
 
[email protected]_func
-def annotated_loops(a: T.handle) -> None:
-    A = T.match_buffer(a, (16,), "float32")
-    for i in range(0, 16, annotations={"pragma_1": "str_value", "pragma_2": 1, 
"pragma_3": 0.0}):
-        A[i] = 0.0
-
-
 @T.prim_func
 def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) 
-> None:
     for i0 in T.serial(10):
-        with T.block("b"):
-            T.reads(a[i0])
-            T.writes(b[i0])
-            b[i0] = a[i0]
+        b[i0] = a[i0]
 
 
 @T.prim_func
@@ -286,41 +194,24 @@ def boolean_handling_after(a: T.Buffer[10, "int8"], b: 
T.Buffer[10, "int8"]) ->
         b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
 
 
[email protected]_func
-def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> 
None:
-    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
-    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
-    # body
-    for i0 in T.serial(10):
-        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
 def test_elementwise():
-    _check(compacted_elementwise_func, flattened_elementwise_func)
+    _check(elementwise_func, flattened_elementwise_func)
 
 
 def test_gpu_workload():
-    _check(compacted_gpu_func, flattened_gpu_func)
+    _check(gpu_func, flattened_gpu_func)
 
 
 def test_symbolic_shape():
-    _check(compacted_symbolic_func, flattened_symbolic_func)
-
-
-def test_predicate():
-    _check(compacted_predicate_func, flattened_predicate_func)
-
-
-def test_unit_loops():
-    _check(compacted_unit_loop_func, flattened_unit_loop_func)
+    _check(symbolic_func, flattened_symbolic_func)
 
 
 def test_multi_alloc():
-    _check(compacted_multi_alloc_func, flattened_multi_alloc_func)
+    _check(multi_alloc_func, flattened_multi_alloc_func)
 
 
 def test_strided_buffer():
-    _check(compacted_strided_buffer_func, flattened_strided_buffer_func)
+    _check(strided_buffer_func, flattened_strided_buffer_func)
 
 
 def test_lower_te():
@@ -332,35 +223,9 @@ def test_lower_te():
     tvm.ir.assert_structural_equal(mod, orig_mod)  # FlattenBuffer should do nothing on TE
 
 
-def test_annotated_loops():
-    mod = tvm.IRModule.from_expr(annotated_loops)
-    mod = tvm.tir.transform.FlattenBuffer()(mod)
-    # _check(annotated_loops, compacted_annotated_loops)
-    attr1 = mod["main"].body
-    attr2 = attr1.body
-    attr3 = attr2.body
-    assert attr1.attr_key == "pragma_1" and attr1.value == "str_value"
-    assert attr2.attr_key == "pragma_2"
-    tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1))
-    assert attr3.attr_key == "pragma_3"
-    tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))
-
-
 def test_boolean_handling():
     _check(boolean_handling_before, boolean_handling_after)
-    # mod = tvm.IRModule.from_expr(boolean_handling_before)
-    # mod = tvm.tir.transform.FlattenBuffer()(mod)
-    # print(mod.script())
 
 
 if __name__ == "__main__":
-    test_elementwise()
-    test_gpu_workload()
-    test_symbolic_shape()
-    test_predicate()
-    test_unit_loops()
-    test_multi_alloc()
-    test_strided_buffer()
-    test_lower_te()
-    test_annotated_loops()
-    test_boolean_handling()
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index d7e13f40aa..1a906b2fb6 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -127,6 +127,7 @@ def test_inject_async_copy():
             f = generate_global_to_shared_vectorized_copy(dtype, vec_size)
 
         mod = tvm.IRModule.from_expr(f)
+        mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
         mod = tvm.tir.transform.FlattenBuffer()(mod)
         if vec_size > 1:
             mod = tvm.tir.transform.VectorizeLoop()(mod)
@@ -154,6 +155,7 @@ def test_inject_async_copy_shared_dyn():
     f = ptx_global_to_shared_dyn_copy_fp16x8
 
     mod = tvm.IRModule.from_expr(f)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     mod = tvm.tir.transform.FlattenBuffer()(mod)
     mod = tvm.tir.transform.VectorizeLoop()(mod)
     mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py
index 6cfe96664d..86f2b6696b 100644
--- a/tests/python/unittest/test_tir_transform_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -611,6 +611,7 @@ def concat_func_3(
 def test_condition_mutually_exclusive():
     mod = IRModule.from_expr(concat_func_3)
     with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+        mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
         mod = tvm.tir.transform.FlattenBuffer()(mod)
         mod = tvm.tir.transform.LoopPartition()(mod)
         mod = tvm.tir.transform.Simplify()(mod)
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
similarity index 67%
copy from tests/python/unittest/test_tir_transform_flatten_buffer.py
copy to tests/python/unittest/test_tir_transform_lower_opaque_block.py
index f1a33a4fb2..9b18c407c4 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -15,14 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te, tir
+import tvm.testing
+from tvm import te
 from tvm.script import tir as T
 
 
 def _check(original, transformed):
     func = original
     mod = tvm.IRModule.from_expr(func)
-    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
     tvm.ir.assert_structural_equal(mod["main"], transformed, True)
 
@@ -49,17 +50,15 @@ def compacted_elementwise_func(a: T.handle, c: T.handle) -> None:
 
 
 @T.prim_func
-def flattened_elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, 256, "float32")
-    C = T.match_buffer(c, 256, "float32")
-    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+def transformed_elementwise_func(a: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), "float32")
+    C = T.match_buffer(c, (16, 16), "float32")
     for i in T.serial(0, 16):
-        B_new = T.allocate([16], "float32", "global")
+        B_new = T.allocate([1, 16], "float32", "global")
         for j in T.serial(0, 16):
-            B_new[j] = A[((i * 16) + j)] + 1.0
+            B_new[0, j] = A[i, j] + 1.0
         for j in T.serial(0, 16):
-            C[((i * 16) + j)] = B_new[j] * 2.0
+            C[i, j] = B_new[0, j] * 2.0
 
 
 @T.prim_func
@@ -86,11 +85,9 @@ def compacted_gpu_func(a: T.handle, c: T.handle) -> None:
 
 
 @T.prim_func
-def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, 256, "float32")
-    C = T.match_buffer(c, 256, "float32")
-    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), "float32")
+    C = T.match_buffer(c, (16, 16), "float32")
 
     i0 = T.env_thread("blockIdx.x")
     i1 = T.env_thread("threadIdx.x")
@@ -99,11 +96,11 @@ def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
     T.launch_thread(i0, 4)
     T.launch_thread(i1, 2)
     T.launch_thread(i2, 2)
-    B = T.allocate([16], "float32", "local")
+    B = T.allocate([1, 16], "float32", "local")
     for j in range(0, 16):
-        B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
+        B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
     for j in range(0, 16):
-        C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
+        C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
 
 
 @T.prim_func
@@ -129,18 +126,16 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) ->
 
 
 @T.prim_func
-def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
-    A = T.match_buffer(a, n * m, "float32")
-    C = T.match_buffer(c, n * m, "float32")
-    T.preflattened_buffer(A, (n, m), "float32", data=A.data)
-    T.preflattened_buffer(C, (n, m), "float32", data=C.data)
+def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+    A = T.match_buffer(a, (n, m), "float32")
+    C = T.match_buffer(c, (n, m), "float32")
 
     for i in range(0, n):
         B = T.allocate([m], "float32", "global")
         for j in range(0, m):
-            B[j] = A[i * m + j] + 1.0
+            B[j] = A[i, j] + 1.0
         for j in range(0, m):
-            C[i * m + j] = B[j] * 2.0
+            C[i, j] = B[j] * 2.0
 
 
 @T.prim_func
@@ -157,11 +152,9 @@ def compacted_predicate_func(a: T.handle, c: T.handle) -> None:
 
 
 @T.prim_func
-def flattened_predicate_func(a: T.handle, c: T.handle) -> None:
+def transformed_predicate_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (32), "float32")
     C = T.match_buffer(c, (32), "float32")
-    T.preflattened_buffer(A, (32), "float32", data=A.data)
-    T.preflattened_buffer(C, (32), "float32", data=C.data)
 
     for i, j in T.grid(5, 7):
         if i * 7 + j < 32:
@@ -181,11 +174,9 @@ def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None:
 
 
 @T.prim_func
-def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None:
+def transformed_unit_loop_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (32), "float32")
     C = T.match_buffer(c, (32), "float32")
-    T.preflattened_buffer(A, (32), "float32", data=A.data)
-    T.preflattened_buffer(C, (32), "float32", data=C.data)
 
     for x, z in T.grid(4, 8):
         C[x * 8 + z] = A[x * 8 + z] + 1.0
@@ -208,11 +199,9 @@ def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None:
 
 
 @T.prim_func
-def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
+def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None:
     A = T.match_buffer(a, (32), "float32")
     D = T.match_buffer(d, (32), "float32")
-    T.preflattened_buffer(A, (32), "float32", data=A.data)
-    T.preflattened_buffer(D, (32), "float32", data=D.data)
 
     for i in range(0, 32):
         B = T.allocate((32,), "float32", "global")
@@ -246,19 +235,17 @@ def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
 
 
 @T.prim_func
-def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (256,), "float32")
-    C = T.match_buffer(c, (256,), "float32")
-    T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
-    T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
-    for i0 in T.serial(0, 4):
-        B_new = T.allocate([68], "float32", "global")
-        for i1 in T.serial(0, 4):
-            for j in T.serial(0, 16):
-                B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
-        for i1 in T.serial(0, 4):
-            for j in T.serial(0, 16):
-                C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
+def transformed_strided_buffer_func(
+    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
+) -> None:
+    # body
+    for i0 in T.serial(4):
+        B = T.allocate([4, 17], "float32", "global")
+        B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
+        for i1, j in T.grid(4, 16):
+            B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
+        for i1, j in T.grid(4, 16):
+            C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2)
 
 
 @T.prim_func
@@ -278,49 +265,38 @@ def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) ->
 
 
 @T.prim_func
-def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
-    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
-    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
-    # body
-    for i0 in T.serial(10):
-        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
-@T.prim_func
-def boolean_handle_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None:
-    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
-    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
+def boolean_handling_after(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None:
     # body
     for i0 in T.serial(10):
-        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
+        b[i0] = a[i0]
 
 
 def test_elementwise():
-    _check(compacted_elementwise_func, flattened_elementwise_func)
+    _check(compacted_elementwise_func, transformed_elementwise_func)
 
 
 def test_gpu_workload():
-    _check(compacted_gpu_func, flattened_gpu_func)
+    _check(compacted_gpu_func, transformed_gpu_func)
 
 
 def test_symbolic_shape():
-    _check(compacted_symbolic_func, flattened_symbolic_func)
+    _check(compacted_symbolic_func, transformed_symbolic_func)
 
 
 def test_predicate():
-    _check(compacted_predicate_func, flattened_predicate_func)
+    _check(compacted_predicate_func, transformed_predicate_func)
 
 
 def test_unit_loops():
-    _check(compacted_unit_loop_func, flattened_unit_loop_func)
+    _check(compacted_unit_loop_func, transformed_unit_loop_func)
 
 
 def test_multi_alloc():
-    _check(compacted_multi_alloc_func, flattened_multi_alloc_func)
+    _check(compacted_multi_alloc_func, transformed_multi_alloc_func)
 
 
 def test_strided_buffer():
-    _check(compacted_strided_buffer_func, flattened_strided_buffer_func)
+    _check(compacted_strided_buffer_func, transformed_strided_buffer_func)
 
 
 def test_lower_te():
@@ -328,14 +304,13 @@ def test_lower_te():
     y = te.compute((1,), lambda i: x[i] + 2)
     s = te.create_schedule(y.op)
     orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
-    mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
-    tvm.ir.assert_structural_equal(mod, orig_mod)  # FlattenBuffer should do nothing on TE
+    mod = tvm.tir.transform.LowerOpaqueBlock()(orig_mod)
+    tvm.ir.assert_structural_equal(mod, orig_mod)  # LowerOpaqueBlock should do nothing on TE
 
 
 def test_annotated_loops():
     mod = tvm.IRModule.from_expr(annotated_loops)
-    mod = tvm.tir.transform.FlattenBuffer()(mod)
-    # _check(annotated_loops, compacted_annotated_loops)
+    mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     attr1 = mod["main"].body
     attr2 = attr1.body
     attr3 = attr2.body
@@ -348,19 +323,7 @@ def test_annotated_loops():
 
 def test_boolean_handling():
     _check(boolean_handling_before, boolean_handling_after)
-    # mod = tvm.IRModule.from_expr(boolean_handling_before)
-    # mod = tvm.tir.transform.FlattenBuffer()(mod)
-    # print(mod.script())
 
 
 if __name__ == "__main__":
-    test_elementwise()
-    test_gpu_workload()
-    test_symbolic_shape()
-    test_predicate()
-    test_unit_loops()
-    test_multi_alloc()
-    test_strided_buffer()
-    test_lower_te()
-    test_annotated_loops()
-    test_boolean_handling()
+    tvm.testing.main()

Reply via email to