This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8302524 [TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs
(#8641)
8302524 is described below
commit 8302524a381655a727db74db97777cba99a2609e
Author: Tristan Konolige <[email protected]>
AuthorDate: Thu Aug 5 12:21:01 2021 -0700
[TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs (#8641)
* [TENSORIR] Add `from_legacy_te_schedule` attr to TE PrimFuncs
The `from_legacy_te_schedule` attr marks PrimFuncs created from TE
scheduling. Passes that only operate on TE schedules check this attr
and are a no-op if it is not found. If `from_legacy_te_schedule` is false or
not set, the PrimFunc is assumed to come from TensorIR. Passes
specific to TensorIR now check for the absence of this attr.
* formatting
* enable all passes regardless of whether the PrimFunc comes from TE (each pass now checks the attr itself)
---
src/driver/driver_api.cc | 31 +++++++++++-----------
src/te/schedule/schedule_postproc_to_primfunc.cc | 4 ++-
src/tir/transforms/compact_buffer_region.cc | 16 +++++++----
src/tir/transforms/convert_blocks_to_opaque.cc | 13 ++++++---
src/tir/transforms/flatten_buffer.cc | 12 ++++++---
src/tir/transforms/inject_prefetch.cc | 13 ++++++---
src/tir/transforms/ir_utils.cc | 5 ++++
src/tir/transforms/ir_utils.h | 11 ++++++++
src/tir/transforms/lower_init_block.cc | 13 ++++++---
.../plan_update_buffer_allocation_location.cc | 15 ++++++++---
src/tir/transforms/storage_flatten.cc | 20 +++++++++-----
tests/python/unittest/test_lower_build.py | 27 ++++++++++++++++---
.../test_tir_transform_compact_buffer_region.py | 11 +++++++-
.../test_tir_transform_convert_blocks_to_opaque.py | 11 +++++++-
.../unittest/test_tir_transform_flatten_buffer.py | 11 +++++++-
.../test_tir_transform_lower_init_block.py | 11 +++++++-
...sform_plan_update_buffer_allocation_location.py | 13 ++++++++-
.../unittest/test_tir_transform_storage_flatten.py | 21 ++++++++++++++-
18 files changed, 204 insertions(+), 54 deletions(-)
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 2008fe5..d6af993 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -167,7 +167,7 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool
for_te_schedule) {
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
transform::PassContext pass_ctx = transform::PassContext::Current();
bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
@@ -214,17 +214,14 @@ Array<tvm::transform::Pass> CreatePassList(bool
disable_loop_partition, bool for
Array<tvm::transform::Pass> pass_list = user_lower_phase0;
// PHASE 1
- if (for_te_schedule) {
- pass_list.push_back(tir::transform::InjectPrefetch());
- pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
- } else {
- pass_list.push_back(tir::transform::LowerInitBlock());
-
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
- pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
- pass_list.push_back(tir::transform::CompactBufferAllocation());
- pass_list.push_back(tir::transform::LowerMatchBuffer());
- pass_list.push_back(tir::transform::FlattenBuffer());
- }
+ pass_list.push_back(tir::transform::InjectPrefetch());
+ pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
+ pass_list.push_back(tir::transform::LowerInitBlock());
+ pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+ pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+ pass_list.push_back(tir::transform::CompactBufferAllocation());
+ pass_list.push_back(tir::transform::LowerMatchBuffer());
+ pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
@@ -288,6 +285,10 @@ IRModule ScheduleToModule(te::Schedule sch, const
Array<ObjectRef>& args, const
tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list,
std::move(stmt), out_binds);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+ // Mark this schedule as being converted from a TE schedule. Makes sure that
+ // the correct TE passes are run.
+ f = WithAttr(std::move(f), "from_legacy_te_schedule", Bool(true));
+
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
@@ -311,7 +312,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
});
IRModule LowerModule(IRModule mod, bool simple_mode) {
- Array<transform::Pass> pass_list = CreatePassList(simple_mode, false);
+ Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
}
@@ -331,7 +332,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const
std::string& name, bool simple_
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
// Get the pass list
- Array<transform::Pass> pass_list = CreatePassList(simple_mode, false);
+ Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
}
@@ -353,7 +354,7 @@ IRModule LowerSchedule(te::Schedule sch, const
Array<ObjectRef>& args, const std
const std::unordered_map<te::Tensor, tir::Buffer>&
binds, bool simple_mode) {
IRModule mod = ScheduleToModule(std::move(sch), args, name, binds);
// Get the legacy TE pass list
- Array<transform::Pass> pass_list = CreatePassList(simple_mode, true);
+ Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(mod, pass_list);
}
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc
b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 2063fc7..439d0ff 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -170,7 +170,9 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef>
arg_list, Stmt body,
}
body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
- return tir::PrimFunc(params, body, VoidType(), buffer_map);
+ // We mark this PrimFunc as coming from a TE schedule
+ return WithAttr(tir::PrimFunc(params, body, VoidType(), buffer_map),
"from_legacy_te_schedule",
+ Bool(true));
}
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
diff --git a/src/tir/transforms/compact_buffer_region.cc
b/src/tir/transforms/compact_buffer_region.cc
index bd1fa9b..b1a4fd4 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -32,6 +32,7 @@
#include "../../support/arena.h"
#include "../../support/utils.h"
#include "../schedule/utils.h"
+#include "ir_utils.h"
namespace tvm {
namespace tir {
@@ -452,11 +453,16 @@ class BufferCompactor : public StmtExprMutator {
};
PrimFunc CompactBufferAllocation(PrimFunc f) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
- BufferAccessRegionCollector::Collect(f);
- fptr->body = BufferCompactor::Compact(f, region);
- return f;
+ // Only apply this pass to TIR that is not from TE schedules
+ if (!IsFromLegacyTESchedule(f)) {
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
+ BufferAccessRegionCollector::Collect(f);
+ fptr->body = BufferCompactor::Compact(f, region);
+ return f;
+ } else {
+ return f;
+ }
}
namespace transform {
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc
b/src/tir/transforms/convert_blocks_to_opaque.cc
index 4c5e1dd..f7629d1 100644
--- a/src/tir/transforms/convert_blocks_to_opaque.cc
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -25,6 +25,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include "ir_utils.h"
+
namespace tvm {
namespace tir {
@@ -83,9 +85,14 @@ class OpaqueBlockConverter : public StmtExprMutator {
};
PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- fptr->body = OpaqueBlockConverter::Substitute(f);
- return f;
+ // Only apply this pass to TIR that is not from TE schedules
+ if (!IsFromLegacyTESchedule(f)) {
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ fptr->body = OpaqueBlockConverter::Substitute(f);
+ return f;
+ } else {
+ return f;
+ }
}
namespace transform {
diff --git a/src/tir/transforms/flatten_buffer.cc
b/src/tir/transforms/flatten_buffer.cc
index f1f914f..85c4123 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -28,6 +28,7 @@
#include <tvm/tir/transform.h>
#include "../../support/utils.h"
+#include "ir_utils.h"
namespace tvm {
namespace tir {
@@ -151,9 +152,14 @@ class BufferFlattener : public StmtExprMutator {
};
PrimFunc FlattenBuffer(PrimFunc f) {
- PrimFuncNode* fptr = f.CopyOnWrite();
- fptr->body = BufferFlattener::Flatten(f);
- return f;
+ // Only apply this pass to TIR that is not from TE schedules
+ if (!IsFromLegacyTESchedule(f)) {
+ PrimFuncNode* fptr = f.CopyOnWrite();
+ fptr->body = BufferFlattener::Flatten(f);
+ return f;
+ } else {
+ return f;
+ }
}
namespace transform {
diff --git a/src/tir/transforms/inject_prefetch.cc
b/src/tir/transforms/inject_prefetch.cc
index 4ce9c76..f20577e 100644
--- a/src/tir/transforms/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -31,6 +31,8 @@
#include <unordered_set>
+#include "ir_utils.h"
+
namespace tvm {
namespace tir {
@@ -96,9 +98,14 @@ namespace transform {
Pass InjectPrefetch() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- auto* n = f.CopyOnWrite();
- n->body = PrefetchInjector()(std::move(n->body));
- return f;
+ // Only apply this pass to TIR from TE schedules
+ if (IsFromLegacyTESchedule(f)) {
+ auto* n = f.CopyOnWrite();
+ n->body = PrefetchInjector()(std::move(n->body));
+ return f;
+ } else {
+ return f;
+ }
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
}
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 7248bd4..a41905c 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -244,5 +244,10 @@ Region ConvertRegion(const MatchBufferRegion&
match_buffer, const Region& region
return result;
}
+Bool IsFromLegacyTESchedule(PrimFunc f) {
+ Optional<Bool> from_legacy_te_schedule =
f->GetAttr("from_legacy_te_schedule", Bool(false));
+ return from_legacy_te_schedule.value();
+}
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 79c5f06..9be18b7 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -27,6 +27,7 @@
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <limits>
@@ -213,6 +214,16 @@ Array<PrimExpr> ConvertIndices(const MatchBufferRegion&
match_buffer,
*/
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region&
region);
+/*!
+ * \brief Check if a given PrimFunc originated from a TE schedule.
+ *
+ * Internally this checks for the `from_legacy_te_schedule` attr of the
PrimFunc.
+ *
+ * \param f PrimFunc to check
+ * \return Whether or not the PrimFunc was created from a te schedule
+ */
+Bool IsFromLegacyTESchedule(PrimFunc f);
+
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_
diff --git a/src/tir/transforms/lower_init_block.cc
b/src/tir/transforms/lower_init_block.cc
index c8aca51..d8621ac 100644
--- a/src/tir/transforms/lower_init_block.cc
+++ b/src/tir/transforms/lower_init_block.cc
@@ -25,6 +25,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include "ir_utils.h"
+
namespace tvm {
namespace tir {
@@ -63,9 +65,14 @@ class InitBlockLower : public StmtMutator {
};
PrimFunc LowerInitBlock(PrimFunc func) {
- auto fptr = func.CopyOnWrite();
- fptr->body = InitBlockLower()(std::move(fptr->body));
- return func;
+ // Only apply this pass to TIR that is not from TE schedules
+ if (!IsFromLegacyTESchedule(func)) {
+ auto fptr = func.CopyOnWrite();
+ fptr->body = InitBlockLower()(std::move(fptr->body));
+ return func;
+ } else {
+ return func;
+ }
}
namespace transform {
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc
b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 949c955..bee11ad 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -26,6 +26,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include "ir_utils.h"
+
namespace tvm {
namespace tir {
@@ -145,10 +147,15 @@ class BufferAllocationLocator : public StmtExprMutator {
};
PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
- auto fptr = func.CopyOnWrite();
- BufferAllocationLocator locator(func);
- fptr->body = locator(fptr->body);
- return func;
+ // Only apply this pass to TIR that is not from TE schedules
+ if (!IsFromLegacyTESchedule(func)) {
+ auto fptr = func.CopyOnWrite();
+ BufferAllocationLocator locator(func);
+ fptr->body = locator(fptr->body);
+ return func;
+ } else {
+ return func;
+ }
}
namespace transform {
diff --git a/src/tir/transforms/storage_flatten.cc
b/src/tir/transforms/storage_flatten.cc
index 38b3a77..2c32cc7 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -500,13 +500,19 @@ class StorageFlattener : public StmtExprMutator {
};
PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool
create_bound_attributes) {
- auto fptr = func.CopyOnWrite();
-
- IRVisitorWithAnalyzer bound_analyzer;
- bound_analyzer(fptr->body);
- fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size,
create_bound_attributes,
- &bound_analyzer)(std::move(fptr->body));
- return func;
+ // Only apply this pass to TIR from TE schedules
+ Optional<Bool> from_legacy_te_schedule =
func->GetAttr("from_legacy_te_schedule", Bool(false));
+ if (from_legacy_te_schedule.value()) {
+ auto fptr = func.CopyOnWrite();
+
+ IRVisitorWithAnalyzer bound_analyzer;
+ bound_analyzer(fptr->body);
+ fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size,
create_bound_attributes,
+ &bound_analyzer)(std::move(fptr->body));
+ return func;
+ } else {
+ return func;
+ }
}
namespace transform {
diff --git a/tests/python/unittest/test_lower_build.py
b/tests/python/unittest/test_lower_build.py
index 4505a7b..e5528a8 100644
--- a/tests/python/unittest/test_lower_build.py
+++ b/tests/python/unittest/test_lower_build.py
@@ -52,6 +52,25 @@ def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
class LoweredModule:
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# function attr dict
+ tir.func_attr(
+ {"global_symbol": "main", "from_legacy_te_schedule": True,
"tir.noalias": True}
+ )
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+ # body
+ for x, y in tir.grid(128, 128):
+ C.data[x * 128 + y] = 0.0
+ for k in tir.serial(0, 128):
+ C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y)
+ tir.load(
+ "float32", A.data, x * 128 + k
+ ) * tir.load("float32", B.data, y * 128 + k)
+
+
+@tvm.script.tir
+class LoweredTIRModule:
+ def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ # function attr dict
tir.func_attr({"global_symbol": "main", "tir.noalias": True})
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
@@ -83,7 +102,7 @@ def test_lower_build_te_schedule():
def test_lower_build_tir_func():
# check lowering
ir_mod = tvm.lower(matmul)
- tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+ tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule())
# check building
mod = tvm.build(matmul, target="llvm")
_check_module_with_numpy(mod)
@@ -95,7 +114,7 @@ def test_lower_build_tir_module():
ir_mod = IRModule({"main": func})
# check lowering
lowered_mod = tvm.lower(ir_mod)
- tvm.ir.assert_structural_equal(lowered_mod, LoweredModule())
+ tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule())
# check building
mod = tvm.build(ir_mod, target="llvm")
_check_module_with_numpy(mod)
@@ -103,8 +122,8 @@ def test_lower_build_tir_module():
def test_lower_build_lowered_module():
# check lowering
- ir_mod = tvm.lower(LoweredModule())
- tvm.ir.assert_structural_equal(ir_mod, LoweredModule())
+ ir_mod = tvm.lower(LoweredTIRModule())
+ tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule())
# check building
mod = tvm.build(ir_mod, target="llvm")
_check_module_with_numpy(mod)
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index a469c6d..fb53b42 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
+from tvm import tir, te
from tvm.script import ty
@@ -371,6 +371,15 @@ def test_match_buffer():
_check(match_buffer_func, compacted_match_buffer_func)
+def test_lower_te():
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ mod = tvm.tir.transform.CompactBufferAllocation()(orig_mod)
+ tvm.ir.assert_structural_equal(mod, orig_mod) # CompactBufferAllocation
should do nothing on TE
+
+
if __name__ == "__main__":
test_elementwise()
test_unschedulable_block()
diff --git
a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
index 38fe1c9..708f1af 100644
--- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
+++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
+from tvm import tir, te
from tvm.script import ty
@@ -73,5 +73,14 @@ def test_elementwise():
_check(elementwise_func, substituted_elementwise_func)
+def test_lower_te():
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ mod = tvm.tir.transform.ConvertBlocksToOpaque()(orig_mod)
+ tvm.ir.assert_structural_equal(mod, orig_mod) # ConvertBlocksToOpaque
should do nothing on TE
+
+
if __name__ == "__main__":
test_elementwise()
diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py
b/tests/python/unittest/test_tir_transform_flatten_buffer.py
index 6929a32..3b2b3cf 100644
--- a/tests/python/unittest/test_tir_transform_flatten_buffer.py
+++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
+from tvm import tir, te
from tvm.script import ty
@@ -234,6 +234,15 @@ def test_multi_alloc():
_check(compacted_multi_alloc_func, flattened_multi_alloc_func)
+def test_lower_te():
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
+ tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do
nothing on TE
+
+
if __name__ == "__main__":
test_elementwise()
test_gpu_workload()
diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py
b/tests/python/unittest/test_tir_transform_lower_init_block.py
index badf5e0..8499c93 100644
--- a/tests/python/unittest/test_tir_transform_lower_init_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_init_block.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
+from tvm import tir, te
from tvm.script import ty
# pylint: disable=no-self-argument
@@ -85,6 +85,15 @@ def test_lower_match_buffer():
tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True)
+def test_lower_te():
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ mod = tvm.tir.transform.LowerInitBlock()(orig_mod)
+ tvm.ir.assert_structural_equal(mod, orig_mod) # LowerInitBlock should do
nothing on TE
+
+
if __name__ == "__main__":
test_lower_reduction()
test_lower_match_buffer()
diff --git
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 022c964..72a2f5e 100644
---
a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++
b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
+from tvm import tir, te
from tvm.script import ty
@@ -149,6 +149,17 @@ def test_match_buffer_allocation():
_check(match_buffer_func, transformed_match_buffer_func)
+def test_lower_te():
+ x = te.placeholder((1,))
+ y = te.compute((1,), lambda i: x[i] + 2)
+ s = te.create_schedule(y.op)
+ orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+ mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(orig_mod)
+ tvm.ir.assert_structural_equal(
+ mod, orig_mod
+ ) # PlanAndUpdateBufferAllocationLocation should do nothing on TE
+
+
if __name__ == "__main__":
test_elementwise()
test_locate_buffer_allocation()
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py
b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 0e9ab86..1dd4a48 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -16,6 +16,8 @@
# under the License.
import tvm
from tvm import te
+from tvm.script import ty
+from tvm.relay import GlobalVar
def test_flatten2():
@@ -102,7 +104,9 @@ def test_flatten_double_buffer():
stmt = ib.get()
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt))
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C], stmt).with_attr("from_legacy_te_schedule",
True)
+ )
with tvm.transform.PassContext(config={"tir.InjectDoubleBuffer":
{"split_loop": 2}}):
mod = tvm.transform.Sequential(
@@ -130,6 +134,21 @@ def test_flatten_double_buffer():
assert count[0] == 4
+@tvm.script.tir
+def tir_func(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, [2, 2])
+ B = tir.match_buffer(a, [2, 2])
+ A[0, 1] = B[1, 1]
+
+
+def test_flatten_tir():
+ orig_mod = tvm.IRModule({GlobalVar("main"): tir_func})
+ mod = tvm.tir.transform.StorageFlatten(64)(orig_mod)
+ tvm.ir.assert_structural_equal(
+ orig_mod, mod
+ ) # StorageFlatten should do nothing to TIR functions
+
+
if __name__ == "__main__":
test_flatten2()
test_flatten_storage_align()