This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dcc8891d42 [Refactor] Remove legacy TE schedule tag (#17701)
dcc8891d42 is described below

commit dcc8891d42c4da84477b75d7234912bbaecd8b61
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Mar 4 00:19:27 2025 +0800

    [Refactor] Remove legacy TE schedule tag (#17701)
    
    This commit removes the legacy TE schedule tag (the `from_legacy_te_schedule`
    PrimFunc attribute and the `IsFromLegacyTESchedule` helper) from TIR
    transforms, as all legacy TE schedules have been removed.
---
 src/tir/transforms/compact_buffer_region.cc              | 15 +++++----------
 src/tir/transforms/convert_blocks_to_opaque.cc           | 11 +++--------
 src/tir/transforms/flatten_buffer.cc                     |  9 +--------
 src/tir/transforms/ir_utils.cc                           |  5 -----
 src/tir/transforms/ir_utils.h                            | 10 ----------
 src/tir/transforms/lift_thread_binding.cc                | 11 +++--------
 src/tir/transforms/lower_cross_thread_reduction.cc       | 11 +++--------
 src/tir/transforms/lower_init_block.cc                   | 11 +++--------
 src/tir/transforms/lower_opaque_block.cc                 | 11 +++--------
 .../transforms/plan_update_buffer_allocation_location.cc | 13 ++++---------
 src/tir/transforms/unify_thread_binding.cc               | 11 +++--------
 .../test_tir_transform_inject_rolling_buffer.py          | 16 ++--------------
 .../tir-transform/test_tir_transform_loop_partition.py   |  2 +-
 tests/python/tvmscript/test_tvmscript_roundtrip.py       |  6 +++---
 14 files changed, 34 insertions(+), 108 deletions(-)

diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 7385af4952..1907c7ca50 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -741,16 +741,11 @@ Stmt BufferCompactorCompact(
 }
 
 PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    PrimFuncNode* fptr = f.CopyOnWrite();
-    auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
-    auto storage_align = CollectStorageAlignAnnotation(f->body);
-    fptr->body = BufferCompactorCompact(f, region, storage_align);
-    return f;
-  } else {
-    return f;
-  }
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
+  auto storage_align = CollectStorageAlignAnnotation(f->body);
+  fptr->body = BufferCompactorCompact(f, region, storage_align);
+  return f;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc
index 9564871349..ab8d98a00e 100644
--- a/src/tir/transforms/convert_blocks_to_opaque.cc
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -108,14 +108,9 @@ class OpaqueBlockConverter : public StmtExprMutator {
 };
 
 PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    PrimFuncNode* fptr = f.CopyOnWrite();
-    fptr->body = OpaqueBlockConverter::Substitute(f);
-    return f;
-  } else {
-    return f;
-  }
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  fptr->body = OpaqueBlockConverter::Substitute(f);
+  return f;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc
index c04e12b839..a6da7f7fc4 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -268,14 +268,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
   Map<Var, Buffer> updated_extern_buffer_map_;
 };
 
-PrimFunc FlattenBuffer(PrimFunc f) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    return BufferFlattener::Flatten(f);
-  } else {
-    return f;
-  }
-}
+PrimFunc FlattenBuffer(PrimFunc f) { return BufferFlattener::Flatten(f); }
 
 namespace transform {
 
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 7026215a01..b63192f50a 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -594,11 +594,6 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region
   return result;
 }
 
-Bool IsFromLegacyTESchedule(PrimFunc f) {
-  Optional<Bool> from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false));
-  return from_legacy_te_schedule.value();
-}
-
 Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
   // extract equations and related vars from condition expression.
   // currently only extract simple integral equations which could be solvable.
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 05345aab86..94054f5d2c 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -234,16 +234,6 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region
  */
 Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer);
 
-/*!
- * \brief Check if a given PrimFunc originated from a TE schedule.
- *
- * Internally this checks for the `from_legacy_te_schedule` attr of the PrimFunc.
- *
- * \param f PrimFunc to check
- * \return Whether or not the PrimFunc was created from a te schedule
- */
-Bool IsFromLegacyTESchedule(PrimFunc f);
-
 /*!
  * \brief Context helper to update domain map within conditional scope.
  * Assume the condition is `0 <= i && i < 9` and domain of i is [0, 20], Then
diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc
index 9d7d455dba..8cb88fa653 100644
--- a/src/tir/transforms/lift_thread_binding.cc
+++ b/src/tir/transforms/lift_thread_binding.cc
@@ -169,14 +169,9 @@ class ThreadBindingLifter : public StmtExprMutator {
 };
 
 PrimFunc LiftThreadBinding(PrimFunc f) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    PrimFuncNode* fptr = f.CopyOnWrite();
-    fptr->body = ThreadBindingLifter()(std::move(fptr->body));
-    return f;
-  } else {
-    return f;
-  }
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  fptr->body = ThreadBindingLifter()(std::move(fptr->body));
+  return f;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index 0146e2aebf..325d8e5bb5 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -920,14 +920,9 @@ class CrossThreadReductionTransformer : public StmtMutator {
 };
 
 PrimFunc LowerCrossThreadReduction(PrimFunc f) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    PrimFuncNode* fptr = f.CopyOnWrite();
-    fptr->body = CrossThreadReductionTransformer()(f->body);
-    return f;
-  } else {
-    return f;
-  }
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  fptr->body = CrossThreadReductionTransformer()(f->body);
+  return f;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/lower_init_block.cc b/src/tir/transforms/lower_init_block.cc
index 17b4e3fb22..3e8fc20431 100644
--- a/src/tir/transforms/lower_init_block.cc
+++ b/src/tir/transforms/lower_init_block.cc
@@ -65,14 +65,9 @@ class InitBlockLower : public StmtMutator {
 };
 
 PrimFunc LowerInitBlock(PrimFunc func) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(func)) {
-    auto fptr = func.CopyOnWrite();
-    fptr->body = InitBlockLower()(std::move(fptr->body));
-    return func;
-  } else {
-    return func;
-  }
+  auto fptr = func.CopyOnWrite();
+  fptr->body = InitBlockLower()(std::move(fptr->body));
+  return func;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc
index 08642a598b..96c6d3759c 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -200,14 +200,9 @@ class OpaqueBlockLower : public StmtExprMutator {
 };
 
 PrimFunc LowerOpaqueBlock(PrimFunc f) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    auto fptr = f.CopyOnWrite();
-    fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
-    return f;
-  } else {
-    return f;
-  }
+  auto fptr = f.CopyOnWrite();
+  fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
+  return f;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index f9ce708c78..5ce8ade208 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -242,15 +242,10 @@ class BufferAllocationLocator : public StmtExprMutator {
 };
 
 PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(func)) {
-    auto fptr = func.CopyOnWrite();
-    BufferAllocationLocator locator(func);
-    fptr->body = locator(fptr->body);
-    return func;
-  } else {
-    return func;
-  }
+  auto fptr = func.CopyOnWrite();
+  BufferAllocationLocator locator(func);
+  fptr->body = locator(fptr->body);
+  return func;
 }
 
 namespace transform {
diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index 02fa333dbe..67c7f05ff4 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -185,14 +185,9 @@ class ThreadBindingUnifier : public StmtExprMutator {
 };
 
 PrimFunc UnifyThreadBinding(PrimFunc f) {
-  // Only apply this pass to TIR that is not from TE schedules
-  if (!IsFromLegacyTESchedule(f)) {
-    PrimFuncNode* fptr = f.CopyOnWrite();
-    fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
-    return f;
-  } else {
-    return f;
-  }
+  PrimFuncNode* fptr = f.CopyOnWrite();
+  fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
+  return f;
 }
 
 namespace transform {
diff --git a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
index 3d8f85bf79..4dd1380c8f 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
@@ -37,13 +37,7 @@ class PreRollingBuffer:
         ),
     ) -> None:
         # function attr dict
-        T.func_attr(
-            {
-                "from_legacy_te_schedule": True,
-                "global_symbol": "main",
-                "tir.noalias": True,
-            }
-        )
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
         A_1 = T.match_buffer(
         A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1
         )
@@ -112,13 +106,7 @@ class PostRollingBuffer:
         ),
     ) -> None:
         # function attr dict
-        T.func_attr(
-            {
-                "from_legacy_te_schedule": True,
-                "global_symbol": "main",
-                "tir.noalias": True,
-            }
-        )
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
         A_1 = T.match_buffer(
         A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64, offset_factor=1
         )
diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py
index 25660880e1..1e079ada55 100644
--- a/tests/python/tir-transform/test_tir_transform_loop_partition.py
+++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py
@@ -238,7 +238,7 @@ def test_cce_loop_3():
 def partitioned_concat(
    A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32"), C: T.Buffer((32,), "float32")
 ) -> None:
-    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
     for i in T.serial(0, 16):
         C[i] = A[i]
     for i in T.serial(0, 16):
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index b44ff5ad72..f29c03c640 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -3644,7 +3644,7 @@ def let_stmt_value():
 def string_stride():
     @T.prim_func
     def main(a: T.handle, b: T.handle):
-        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
         n = T.int32()
         A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto")
         B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto")
@@ -3663,7 +3663,7 @@ def string_stride():
 def string_stride_int64():
     @T.prim_func
     def main(a: T.handle, b: T.handle):
-        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
         n = T.int64()
         A_s0 = T.int64()
         B_s0 = T.int64()
@@ -3679,7 +3679,7 @@ def merge_shape_var_def():
     # uninitialized vars
     @T.prim_func(check_well_formed=False)
     def main(A: T.handle, B: T.handle):
-        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
         m, n = T.int32(), T.int32()
         A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto")
         B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"), buffer_type="auto")

Reply via email to