This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dcc8891d42 [Refactor] Remove legacy TE schedule tag (#17701)
dcc8891d42 is described below
commit dcc8891d42c4da84477b75d7234912bbaecd8b61
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Mar 4 00:19:27 2025 +0800
[Refactor] Remove legacy TE schedule tag (#17701)
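This commit removes the legacy TE schedule tag (the `from_legacy_te_schedule`
PrimFunc attribute and its `IsFromLegacyTESchedule` helper) from the TIR
transforms, since all legacy TE schedules have been removed.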
---
src/tir/transforms/compact_buffer_region.cc | 15 +++++----------
src/tir/transforms/convert_blocks_to_opaque.cc | 11 +++--------
src/tir/transforms/flatten_buffer.cc | 9 +--------
src/tir/transforms/ir_utils.cc | 5 -----
src/tir/transforms/ir_utils.h | 10 ----------
src/tir/transforms/lift_thread_binding.cc | 11 +++--------
src/tir/transforms/lower_cross_thread_reduction.cc | 11 +++--------
src/tir/transforms/lower_init_block.cc | 11 +++--------
src/tir/transforms/lower_opaque_block.cc | 11 +++--------
.../transforms/plan_update_buffer_allocation_location.cc | 13 ++++---------
src/tir/transforms/unify_thread_binding.cc | 11 +++--------
.../test_tir_transform_inject_rolling_buffer.py | 16 ++--------------
.../tir-transform/test_tir_transform_loop_partition.py | 2 +-
tests/python/tvmscript/test_tvmscript_roundtrip.py | 6 +++---
14 files changed, 34 insertions(+), 108 deletions(-)
diff --git a/src/tir/transforms/compact_buffer_region.cc
b/src/tir/transforms/compact_buffer_region.cc
index 7385af4952..1907c7ca50 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -741,16 +741,11 @@ Stmt BufferCompactorCompact(
}
PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- auto region = BufferAccessRegionCollector::Collect(f,
/*collect_inbound=*/is_strict);
- auto storage_align = CollectStorageAlignAnnotation(f->body);
- fptr->body = BufferCompactorCompact(f, region, storage_align);
- return f;
- } else {
- return f;
- }
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ auto region = BufferAccessRegionCollector::Collect(f,
/*collect_inbound=*/is_strict);
+ auto storage_align = CollectStorageAlignAnnotation(f->body);
+ fptr->body = BufferCompactorCompact(f, region, storage_align);
+ return f;
}
namespace transform {
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc
b/src/tir/transforms/convert_blocks_to_opaque.cc
index 9564871349..ab8d98a00e 100644
--- a/src/tir/transforms/convert_blocks_to_opaque.cc
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -108,14 +108,9 @@ class OpaqueBlockConverter : public StmtExprMutator {
};
PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- fptr->body = OpaqueBlockConverter::Substitute(f);
- return f;
- } else {
- return f;
- }
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ fptr->body = OpaqueBlockConverter::Substitute(f);
+ return f;
}
namespace transform {
diff --git a/src/tir/transforms/flatten_buffer.cc
b/src/tir/transforms/flatten_buffer.cc
index c04e12b839..a6da7f7fc4 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -268,14 +268,7 @@ class BufferFlattener : public
arith::IRMutatorWithAnalyzer {
Map<Var, Buffer> updated_extern_buffer_map_;
};
-PrimFunc FlattenBuffer(PrimFunc f) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- return BufferFlattener::Flatten(f);
- } else {
- return f;
- }
-}
+PrimFunc FlattenBuffer(PrimFunc f) { return BufferFlattener::Flatten(f); }
namespace transform {
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 7026215a01..b63192f50a 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -594,11 +594,6 @@ Region ConvertRegion(const MatchBufferRegion&
match_buffer, const Region& region
return result;
}
-Bool IsFromLegacyTESchedule(PrimFunc f) {
- Optional<Bool> from_legacy_te_schedule =
f->GetAttr("from_legacy_te_schedule", Bool(false));
- return from_legacy_te_schedule.value();
-}
-
Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
// extract equations and related vars from condition expression.
// currently only extract simple integral equations which could be solvable.
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 05345aab86..94054f5d2c 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -234,16 +234,6 @@ Region ConvertRegion(const MatchBufferRegion&
match_buffer, const Region& region
*/
Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer);
-/*!
- * \brief Check if a given PrimFunc originated from a TE schedule.
- *
- * Internally this checks for the `from_legacy_te_schedule` attr of the
PrimFunc.
- *
- * \param f PrimFunc to check
- * \return Whether or not the PrimFunc was created from a te schedule
- */
-Bool IsFromLegacyTESchedule(PrimFunc f);
-
/*!
* \brief Context helper to update domain map within conditional scope.
* Assume the condition is `0 <= i && i < 9` and domain of i is [0, 20], Then
diff --git a/src/tir/transforms/lift_thread_binding.cc
b/src/tir/transforms/lift_thread_binding.cc
index 9d7d455dba..8cb88fa653 100644
--- a/src/tir/transforms/lift_thread_binding.cc
+++ b/src/tir/transforms/lift_thread_binding.cc
@@ -169,14 +169,9 @@ class ThreadBindingLifter : public StmtExprMutator {
};
PrimFunc LiftThreadBinding(PrimFunc f) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- fptr->body = ThreadBindingLifter()(std::move(fptr->body));
- return f;
- } else {
- return f;
- }
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ fptr->body = ThreadBindingLifter()(std::move(fptr->body));
+ return f;
}
namespace transform {
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc
b/src/tir/transforms/lower_cross_thread_reduction.cc
index 0146e2aebf..325d8e5bb5 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -920,14 +920,9 @@ class CrossThreadReductionTransformer : public StmtMutator
{
};
PrimFunc LowerCrossThreadReduction(PrimFunc f) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- fptr->body = CrossThreadReductionTransformer()(f->body);
- return f;
- } else {
- return f;
- }
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ fptr->body = CrossThreadReductionTransformer()(f->body);
+ return f;
}
namespace transform {
diff --git a/src/tir/transforms/lower_init_block.cc
b/src/tir/transforms/lower_init_block.cc
index 17b4e3fb22..3e8fc20431 100644
--- a/src/tir/transforms/lower_init_block.cc
+++ b/src/tir/transforms/lower_init_block.cc
@@ -65,14 +65,9 @@ class InitBlockLower : public StmtMutator {
};
PrimFunc LowerInitBlock(PrimFunc func) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(func)) {
- auto fptr = func.CopyOnWrite();
- fptr->body = InitBlockLower()(std::move(fptr->body));
- return func;
- } else {
- return func;
- }
+ auto fptr = func.CopyOnWrite();
+ fptr->body = InitBlockLower()(std::move(fptr->body));
+ return func;
}
namespace transform {
diff --git a/src/tir/transforms/lower_opaque_block.cc
b/src/tir/transforms/lower_opaque_block.cc
index 08642a598b..96c6d3759c 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -200,14 +200,9 @@ class OpaqueBlockLower : public StmtExprMutator {
};
PrimFunc LowerOpaqueBlock(PrimFunc f) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- auto fptr = f.CopyOnWrite();
- fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
- return f;
- } else {
- return f;
- }
+ auto fptr = f.CopyOnWrite();
+ fptr->body = OpaqueBlockLower::Rewrite(std::move(fptr->body));
+ return f;
}
namespace transform {
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc
b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index f9ce708c78..5ce8ade208 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -242,15 +242,10 @@ class BufferAllocationLocator : public StmtExprMutator {
};
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(func)) {
- auto fptr = func.CopyOnWrite();
- BufferAllocationLocator locator(func);
- fptr->body = locator(fptr->body);
- return func;
- } else {
- return func;
- }
+ auto fptr = func.CopyOnWrite();
+ BufferAllocationLocator locator(func);
+ fptr->body = locator(fptr->body);
+ return func;
}
namespace transform {
diff --git a/src/tir/transforms/unify_thread_binding.cc
b/src/tir/transforms/unify_thread_binding.cc
index 02fa333dbe..67c7f05ff4 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -185,14 +185,9 @@ class ThreadBindingUnifier : public StmtExprMutator {
};
PrimFunc UnifyThreadBinding(PrimFunc f) {
- // Only apply this pass to TIR that is not from TE schedules
- if (!IsFromLegacyTESchedule(f)) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
- return f;
- } else {
- return f;
- }
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
+ return f;
}
namespace transform {
diff --git
a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
index 3d8f85bf79..4dd1380c8f 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py
@@ -37,13 +37,7 @@ class PreRollingBuffer:
),
) -> None:
# function attr dict
- T.func_attr(
- {
- "from_legacy_te_schedule": True,
- "global_symbol": "main",
- "tir.noalias": True,
- }
- )
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_1 = T.match_buffer(
A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64,
offset_factor=1
)
@@ -112,13 +106,7 @@ class PostRollingBuffer:
),
) -> None:
# function attr dict
- T.func_attr(
- {
- "from_legacy_te_schedule": True,
- "global_symbol": "main",
- "tir.noalias": True,
- }
- )
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_1 = T.match_buffer(
A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=64,
offset_factor=1
)
diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py
b/tests/python/tir-transform/test_tir_transform_loop_partition.py
index 25660880e1..1e079ada55 100644
--- a/tests/python/tir-transform/test_tir_transform_loop_partition.py
+++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py
@@ -238,7 +238,7 @@ def test_cce_loop_3():
def partitioned_concat(
A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32"), C:
T.Buffer((32,), "float32")
) -> None:
- T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True})
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i in T.serial(0, 16):
C[i] = A[i]
for i in T.serial(0, 16):
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index b44ff5ad72..f29c03c640 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -3644,7 +3644,7 @@ def let_stmt_value():
def string_stride():
@T.prim_func
def main(a: T.handle, b: T.handle):
- T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True})
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
n = T.int32()
A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto")
B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto")
@@ -3663,7 +3663,7 @@ def string_stride():
def string_stride_int64():
@T.prim_func
def main(a: T.handle, b: T.handle):
- T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True})
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
n = T.int64()
A_s0 = T.int64()
B_s0 = T.int64()
@@ -3679,7 +3679,7 @@ def merge_shape_var_def():
# uninitialized vars
@T.prim_func(check_well_formed=False)
def main(A: T.handle, B: T.handle):
- T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True})
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
m, n = T.int32(), T.int32()
A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"),
buffer_type="auto")
B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"),
buffer_type="auto")