This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8302524  [TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs 
(#8641)
8302524 is described below

commit 8302524a381655a727db74db97777cba99a2609e
Author: Tristan Konolige <[email protected]>
AuthorDate: Thu Aug 5 12:21:01 2021 -0700

    [TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs (#8641)
    
    * [TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs
    
    The `from_legacy_te_schedule` attr marks PrimFuncs created from TE
    scheduling. Passes that only operate on TE schedules check for this attr
    and are no-ops if it is not found. If `from_legacy_te_schedule` is false or
    not set, then it is assumed that the PrimFunc is from TensorIR. Passes
    specific to TensorIR now check for the absence of this attr.
    
    * formatting
    
    * enable passes regardless of te or not
---
 src/driver/driver_api.cc                           | 31 +++++++++++-----------
 src/te/schedule/schedule_postproc_to_primfunc.cc   |  4 ++-
 src/tir/transforms/compact_buffer_region.cc        | 16 +++++++----
 src/tir/transforms/convert_blocks_to_opaque.cc     | 13 ++++++---
 src/tir/transforms/flatten_buffer.cc               | 12 ++++++---
 src/tir/transforms/inject_prefetch.cc              | 13 ++++++---
 src/tir/transforms/ir_utils.cc                     |  5 ++++
 src/tir/transforms/ir_utils.h                      | 11 ++++++++
 src/tir/transforms/lower_init_block.cc             | 13 ++++++---
 .../plan_update_buffer_allocation_location.cc      | 15 ++++++++---
 src/tir/transforms/storage_flatten.cc              | 20 +++++++++-----
 tests/python/unittest/test_lower_build.py          | 27 ++++++++++++++++---
 .../test_tir_transform_compact_buffer_region.py    | 11 +++++++-
 .../test_tir_transform_convert_blocks_to_opaque.py | 11 +++++++-
 .../unittest/test_tir_transform_flatten_buffer.py  | 11 +++++++-
 .../test_tir_transform_lower_init_block.py         | 11 +++++++-
 ...sform_plan_update_buffer_allocation_location.py | 13 ++++++++-
 .../unittest/test_tir_transform_storage_flatten.py | 21 ++++++++++++++-
 18 files changed, 204 insertions(+), 54 deletions(-)

diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 2008fe5..d6af993 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -167,7 +167,7 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool 
for_te_schedule) {
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   transform::PassContext pass_ctx = transform::PassContext::Current();
 
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", 
Bool(false)).value();
@@ -214,17 +214,14 @@ Array<tvm::transform::Pass> CreatePassList(bool 
disable_loop_partition, bool for
   Array<tvm::transform::Pass> pass_list = user_lower_phase0;
 
   // PHASE 1
-  if (for_te_schedule) {
-    pass_list.push_back(tir::transform::InjectPrefetch());
-    pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
-  } else {
-    pass_list.push_back(tir::transform::LowerInitBlock());
-    
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
-    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
-    pass_list.push_back(tir::transform::CompactBufferAllocation());
-    pass_list.push_back(tir::transform::LowerMatchBuffer());
-    pass_list.push_back(tir::transform::FlattenBuffer());
-  }
+  pass_list.push_back(tir::transform::InjectPrefetch());
+  pass_list.push_back(tir::transform::StorageFlatten(64, 
instrument_bound_checkers));
+  pass_list.push_back(tir::transform::LowerInitBlock());
+  pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+  pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+  pass_list.push_back(tir::transform::CompactBufferAllocation());
+  pass_list.push_back(tir::transform::LowerMatchBuffer());
+  pass_list.push_back(tir::transform::FlattenBuffer());
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
@@ -288,6 +285,10 @@ IRModule ScheduleToModule(te::Schedule sch, const 
Array<ObjectRef>& args, const
   tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, 
std::move(stmt), out_binds);
   f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
 
+  // Mark this schedule as being converted from a TE schedule. Makes sure that
+  // the correct TE passes are run.
+  f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true));
+
   bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
 
   if (noalias) {
@@ -311,7 +312,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
     });
 
 IRModule LowerModule(IRModule mod, bool simple_mode) {
-  Array<transform::Pass> pass_list = CreatePassList(simple_mode, false);
+  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
   return LowerWithPassList(std::move(mod), pass_list);
 }
 
@@ -331,7 +332,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const 
std::string& name, bool simple_
   IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
 
   // Get the pass list
-  Array<transform::Pass> pass_list = CreatePassList(simple_mode, false);
+  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
   return LowerWithPassList(std::move(mod), pass_list);
 }
 
@@ -353,7 +354,7 @@ IRModule LowerSchedule(te::Schedule sch, const 
Array<ObjectRef>& args, const std
                        const std::unordered_map<te::Tensor, tir::Buffer>& 
binds, bool simple_mode) {
   IRModule mod = ScheduleToModule(std::move(sch), args, name, binds);
   // Get the legacy TE pass list
-  Array<transform::Pass> pass_list = CreatePassList(simple_mode, true);
+  Array<transform::Pass> pass_list = CreatePassList(simple_mode);
   return LowerWithPassList(mod, pass_list);
 }
 
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc 
b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 2063fc7..439d0ff 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -170,7 +170,9 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> 
arg_list, Stmt body,
   }
 
   body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
-  return tir::PrimFunc(params, body, VoidType(), buffer_map);
+  // We mark this PrimFunc as coming from a TE schedule
+  return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map), 
"from_legacy_te_schedule",
+                  Bool(true));
 }
 
 TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
diff --git a/src/tir/transforms/compact_buffer_region.cc 
b/src/tir/transforms/compact_buffer_region.cc
index bd1fa9b..b1a4fd4 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -32,6 +32,7 @@
 #include "../../support/arena.h"
 #include "../../support/utils.h"
 #include "../schedule/utils.h"
+#include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
@@ -452,11 +453,16 @@ class BufferCompactor : public StmtExprMutator {
 };
 
 PrimFunc CompactBufferAllocation(PrimFunc f) {
-  PrimFuncNode* fptr = f.CopyOnWrite();
-  std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
-      BufferAccessRegionCollector::Collect(f);
-  fptr->body = BufferCompactor::Compact(f, region);
-  return f;
+  // Only apply this pass to TIR that is not from TE schedules
+  if (!IsFromLegacyTESchedule(f)) {
+    PrimFuncNode* fptr = f.CopyOnWrite();
+    std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
+        BufferAccessRegionCollector::Collect(f);
+    fptr->body = BufferCompactor::Compact(f, region);
+    return f;
+  } else {
+    return f;
+  }
 }
 
 namespace transform {
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc 
b/src/tir/transforms/convert_blocks_to_opaque.cc
index 4c5e1dd..f7629d1 100644
--- a/src/tir/transforms/convert_blocks_to_opaque.cc
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -25,6 +25,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include "ir_utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -83,9 +85,14 @@ class OpaqueBlockConverter : public StmtExprMutator {
 };
 
 PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
-  PrimFuncNode* fptr = f.CopyOnWrite();
-  fptr->body = OpaqueBlockConverter::Substitute(f);
-  return f;
+  // Only apply this pass to TIR that is not from TE schedules
+  if (!IsFromLegacyTESchedule(f)) {
+    PrimFuncNode* fptr = f.CopyOnWrite();
+    fptr->body = OpaqueBlockConverter::Substitute(f);
+    return f;
+  } else {
+    return f;
+  }
 }
 
 namespace transform {
diff --git a/src/tir/transforms/flatten_buffer.cc 
b/src/tir/transforms/flatten_buffer.cc
index f1f914f..85c4123 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -28,6 +28,7 @@
 #include <tvm/tir/transform.h>
 
 #include "../../support/utils.h"
+#include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
@@ -151,9 +152,14 @@ class BufferFlattener : public StmtExprMutator {
 };
 
 PrimFunc FlattenBuffer(PrimFunc f) {
-  PrimFuncNode* fptr = f.CopyOnWrite();
-  fptr->body = BufferFlattener::Flatten(f);
-  return f;
+  // Only apply this pass to TIR that is not from TE schedules
+  if (!IsFromLegacyTESchedule(f)) {
+    PrimFuncNode* fptr = f.CopyOnWrite();
+    fptr->body = BufferFlattener::Flatten(f);
+    return f;
+  } else {
+    return f;
+  }
 }
 
 namespace transform {
diff --git a/src/tir/transforms/inject_prefetch.cc 
b/src/tir/transforms/inject_prefetch.cc
index 4ce9c76..f20577e 100644
--- a/src/tir/transforms/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -31,6 +31,8 @@
 
 #include <unordered_set>
 
+#include "ir_utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -96,9 +98,14 @@ namespace transform {
 
 Pass InjectPrefetch() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    auto* n = f.CopyOnWrite();
-    n->body = PrefetchInjector()(std::move(n->body));
-    return f;
+    // Only apply this pass to TIR from TE schedules
+    if (IsFromLegacyTESchedule(f)) {
+      auto* n = f.CopyOnWrite();
+      n->body = PrefetchInjector()(std::move(n->body));
+      return f;
+    } else {
+      return f;
+    }
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
 }
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 7248bd4..a41905c 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -244,5 +244,10 @@ Region ConvertRegion(const MatchBufferRegion& 
match_buffer, const Region& region
   return result;
 }
 
+Bool IsFromLegacyTESchedule(PrimFunc f) {
+  Optional<Bool> from_legacy_te_schedule = 
f->GetAttr("from_legacy_te_schedule", Bool(false));
+  return from_legacy_te_schedule.value();
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 79c5f06..9be18b7 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -27,6 +27,7 @@
 #include <tvm/runtime/device_api.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
 
 #include <limits>
@@ -213,6 +214,16 @@ Array<PrimExpr> ConvertIndices(const MatchBufferRegion& 
match_buffer,
  */
 Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& 
region);
 
+/*!
+ * \brief Check if a given PrimFunc originated from a TE schedule.
+ *
+ * Internally this checks for the `from_legacy_te_schedule` attr of the 
PrimFunc.
+ *
+ * \param f PrimFunc to check
+ * \return Whether or not the PrimFunc was created from a te schedule
+ */
+Bool IsFromLegacyTESchedule(PrimFunc f);
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_TRANSFORMS_IR_UTILS_H_
diff --git a/src/tir/transforms/lower_init_block.cc 
b/src/tir/transforms/lower_init_block.cc
index c8aca51..d8621ac 100644
--- a/src/tir/transforms/lower_init_block.cc
+++ b/src/tir/transforms/lower_init_block.cc
@@ -25,6 +25,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include "ir_utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -63,9 +65,14 @@ class InitBlockLower : public StmtMutator {
 };
 
 PrimFunc LowerInitBlock(PrimFunc func) {
-  auto fptr = func.CopyOnWrite();
-  fptr->body = InitBlockLower()(std::move(fptr->body));
-  return func;
+  // Only apply this pass to TIR that is not from TE schedules
+  if (!IsFromLegacyTESchedule(func)) {
+    auto fptr = func.CopyOnWrite();
+    fptr->body = InitBlockLower()(std::move(fptr->body));
+    return func;
+  } else {
+    return func;
+  }
 }
 
 namespace transform {
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc 
b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 949c955..bee11ad 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -26,6 +26,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include "ir_utils.h"
+
 namespace tvm {
 namespace tir {
 
@@ -145,10 +147,15 @@ class BufferAllocationLocator : public StmtExprMutator {
 };
 
 PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
-  auto fptr = func.CopyOnWrite();
-  BufferAllocationLocator locator(func);
-  fptr->body = locator(fptr->body);
-  return func;
+  // Only apply this pass to TIR that is not from TE schedules
+  if (!IsFromLegacyTESchedule(func)) {
+    auto fptr = func.CopyOnWrite();
+    BufferAllocationLocator locator(func);
+    fptr->body = locator(fptr->body);
+    return func;
+  } else {
+    return func;
+  }
 }
 
 namespace transform {
diff --git a/src/tir/transforms/storage_flatten.cc 
b/src/tir/transforms/storage_flatten.cc
index 38b3a77..2c32cc7 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -500,13 +500,19 @@ class StorageFlattener : public StmtExprMutator {
 };
 
 PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool 
create_bound_attributes) {
-  auto fptr = func.CopyOnWrite();
-
-  IRVisitorWithAnalyzer bound_analyzer;
-  bound_analyzer(fptr->body);
-  fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, 
create_bound_attributes,
-                                &bound_analyzer)(std::move(fptr->body));
-  return func;
+  // Only apply this pass to TIR from TE schedules
+  Optional<Bool> from_legacy_te_schedule = 
func->GetAttr("from_legacy_te_schedule", Bool(false));
+  if (from_legacy_te_schedule.value()) {
+    auto fptr = func.CopyOnWrite();
+
+    IRVisitorWithAnalyzer bound_analyzer;
+    bound_analyzer(fptr->body);
+    fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, 
create_bound_attributes,
+                                  &bound_analyzer)(std::move(fptr->body));
+    return func;
+  } else {
+    return func;
+  }
 }
 
 namespace transform {
diff --git a/tests/python/unittest/test_lower_build.py 
b/tests/python/unittest/test_lower_build.py
index 4505a7b..e5528a8 100644
--- a/tests/python/unittest/test_lower_build.py
+++ b/tests/python/unittest/test_lower_build.py
@@ -52,6 +52,25 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
 class LoweredModule:
     def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
         # function attr dict
+        tir.func_attr(
+            {"global_symbol": "main", "from_legacy_te_schedule": True, 
"tir.noalias": True}
+        )
+        A = tir.match_buffer(a, [128, 128])
+        B = tir.match_buffer(b, [128, 128])
+        C = tir.match_buffer(c, [128, 128])
+        # body
+        for x, y in tir.grid(128, 128):
+            C.data[x * 128 + y] = 0.0
+            for k in tir.serial(0, 128):
+                C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) 
+ tir.load(
+                    "float32", A.data, x * 128 + k
+                ) * tir.load("float32", B.data, y * 128 + k)
+
+
+@tvm.script.tir
+class LoweredTIRModule:
+    def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+        # function attr dict
         tir.func_attr({"global_symbol": "main", "tir.noalias": True})
         A = tir.match_buffer(a, [128, 128])
         B = tir.match_buffer(b, [128, 128])
@@ -83,7 +102,7 @@ def test_lower_build_te_schedule():
 def test_lower_build_tir_func():
     # check lowering
     ir_mod = tvm.lower(matmul)
-    tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+    tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule())
     # check building
     mod = tvm.build(matmul, target="llvm")
     _check_module_with_numpy(mod)
@@ -95,7 +114,7 @@ def test_lower_build_tir_module():
     ir_mod = IRModule({"main": func})
     # check lowering
     lowered_mod = tvm.lower(ir_mod)
-    tvm.ir.assert_structural_equal(lowered_mod, LoweredModule())
+    tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule())
     # check building
     mod = tvm.build(ir_mod, target="llvm")
     _check_module_with_numpy(mod)
@@ -103,8 +122,8 @@ def test_lower_build_tir_module():
 
 def test_lower_build_lowered_module():
     # check lowering
-    ir_mod = tvm.lower(LoweredModule())
-    tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+    ir_mod = tvm.lower(LoweredTIRModule())
+    tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule())
     # check building
     mod = tvm.build(ir_mod, target="llvm")
     _check_module_with_numpy(mod)
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py 
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index a469c6d..fb53b42 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir
+from tvm import tir, te
 from tvm.script import ty
 
 
@@ -371,6 +371,15 @@ def test_match_buffer():
     _check(match_buffer_func, compacted_match_buffer_func)
 
 
+def test_lower_te():
+    x = te.placeholder((1,))
+    y = te.compute((1,), lambda i: x[i] + 2)
+    s = te.create_schedule(y.op)
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+    mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod)
+    tvm.ir.assert_structural_equal(mod, orig_mod)  # CompactBufferAllocation 
should do nothing on TE
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_unschedulable_block()
diff --git 
a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py 
b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
index 38fe1c9..708f1af 100644
--- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
+++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir
+from tvm import tir, te
 from tvm.script import ty
 
 
@@ -73,5 +73,14 @@ def test_elementwise():
     _check(elementwise_func, substituted_elementwise_func)
 
 
+def test_lower_te():
+    x = te.placeholder((1,))
+    y = te.compute((1,), lambda i: x[i] + 2)
+    s = te.create_schedule(y.op)
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+    mod = tvm.tir.transform.ConvertBlocksToOpaque()(orig_mod)
+    tvm.ir.assert_structural_equal(mod, orig_mod)  # ConvertBlocksToOpaque 
should do nothing on TE
+
+
 if __name__ == "__main__":
     test_elementwise()
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py 
b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 6929a32..3b2b3cf 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir
+from tvm import tir, te
 from tvm.script import ty
 
 
@@ -234,6 +234,15 @@ def test_multi_alloc():
     _check(compacted_multi_alloc_func, flattened_multi_alloc_func)
 
 
+def test_lower_te():
+    x = te.placeholder((1,))
+    y = te.compute((1,), lambda i: x[i] + 2)
+    s = te.create_schedule(y.op)
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+    mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
+    tvm.ir.assert_structural_equal(mod, orig_mod)  # FlattenBuffer should do 
nothing on TE
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_gpu_workload()
diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py 
b/tests/python/unittest/test_tir_transform_lower_init_block.py
index badf5e0..8499c93 100644
--- a/tests/python/unittest/test_tir_transform_lower_init_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_init_block.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir
+from tvm import tir, te
 from tvm.script import ty
 
 # pylint: disable=no-self-argument
@@ -85,6 +85,15 @@ def test_lower_match_buffer():
     tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True)
 
 
+def test_lower_te():
+    x = te.placeholder((1,))
+    y = te.compute((1,), lambda i: x[i] + 2)
+    s = te.create_schedule(y.op)
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+    mod = tvm.tir.transform.LowerInitBlock()(orig_mod)
+    tvm.ir.assert_structural_equal(mod, orig_mod)  # LowerInitBlock should do 
nothing on TE
+
+
 if __name__ == "__main__":
     test_lower_reduction()
     test_lower_match_buffer()
diff --git 
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
 
b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 022c964..72a2f5e 100644
--- 
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ 
b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import tir
+from tvm import tir, te
 from tvm.script import ty
 
 
@@ -149,6 +149,17 @@ def test_match_buffer_allocation():
     _check(match_buffer_func, transformed_match_buffer_func)
 
 
+def test_lower_te():
+    x = te.placeholder((1,))
+    y = te.compute((1,), lambda i: x[i] + 2)
+    s = te.create_schedule(y.op)
+    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(orig_mod)
+    tvm.ir.assert_structural_equal(
+        mod, orig_mod
+    )  # PlanAndUpdateBufferAllocationLocation should do nothing on TE
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_locate_buffer_allocation()
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py 
b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 0e9ab86..1dd4a48 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -16,6 +16,8 @@
 # under the License.
 import tvm
 from tvm import te
+from tvm.script import ty
+from tvm.relay import GlobalVar
 
 
 def test_flatten2():
@@ -102,7 +104,9 @@ def test_flatten_double_buffer():
 
     stmt = ib.get()
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt))
+    mod = tvm.IRModule.from_expr(
+        tvm.tir.PrimFunc([A, C], stmt).with_attr("from_legacy_te_schedule", 
True)
+    )
 
     with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer": 
{"split_loop": 2}}):
         mod = tvm.transform.Sequential(
@@ -130,6 +134,21 @@ def test_flatten_double_buffer():
     assert count[0] == 4
 
 
+@tvm.script.tir
+def tir_func(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, [2, 2])
+    B = tir.match_buffer(a, [2, 2])
+    A[0, 1] = B[1, 1]
+
+
+def test_flatten_tir():
+    orig_mod = tvm.IRModule({GlobalVar("main"): tir_func})
+    mod = tvm.tir.transform.StorageFlatten(64)(orig_mod)
+    tvm.ir.assert_structural_equal(
+        orig_mod, mod
+    )  # StorageFlatten should do nothing to TIR functions
+
+
 if __name__ == "__main__":
     test_flatten2()
     test_flatten_storage_align()

Reply via email to