This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 3264895 [TIR][REFACTOR] Migrate low-level passes in tvm.lower to the
Unified IR pass manager. (#5364)
3264895 is described below
commit 326489505d67226c21e0b73a6aeef60d50f2cd6e
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Apr 18 12:33:58 2020 -0700
[TIR][REFACTOR] Migrate low-level passes in tvm.lower to the Unified IR
pass manager. (#5364)
- Migrate BoundCheckers and Simplify
- Migrate RewriteUnsafeSelect and RemoveNoOp
- Migrate UnrollLoop and StorageRewrite
- Migrate InjectDoubleBuffer and InjectVirtualThread
- Migrate LoopPartition and Vectorize
- Migrate CoProcSync, LiftAttrScope, InjectCopyIntrin
We still keep ir_pass registrations for now.
Need a separate PR to refactor the parts before the StorageFlatten.
---
include/tvm/tir/analysis.h | 1 -
include/tvm/tir/ir_pass.h | 140 ---------------
include/tvm/tir/transform.h | 118 ++++++++++++
python/tvm/driver/build_module.py | 4 +
python/tvm/tir/transform/transform.py | 197 +++++++++++++++++++++
src/arith/compute_expr.h | 1 +
src/driver/driver_api.cc | 132 +++++++-------
src/tir/pass/ffi_api.cc | 15 --
src/tir/{pass => transforms}/bound_checker.cc | 35 +++-
src/tir/{pass => transforms}/coproc_sync.cc | 26 ++-
src/tir/{pass => transforms}/inject_copy_intrin.cc | 24 ++-
.../{pass => transforms}/inject_double_buffer.cc | 25 ++-
.../{pass => transforms}/inject_virtual_thread.cc | 21 +++
src/tir/{pass => transforms}/lift_attr_scope.cc | 24 ++-
src/tir/{pass => transforms}/loop_partition.cc | 30 +++-
.../transforms/lower_device_storage_access_info.cc | 2 +
src/tir/transforms/narrow_datatype.cc | 4 +
src/tir/{pass => transforms}/remove_no_op.cc | 23 +++
.../{pass => transforms}/rewrite_unsafe_select.cc | 22 ++-
.../transforms/simplify.cc} | 24 ++-
src/tir/{pass => transforms}/storage_rewrite.cc | 25 ++-
src/tir/{pass => transforms}/unroll_loop.cc | 33 +++-
src/tir/{pass => transforms}/vectorize_loop.cc | 36 +++-
.../unittest/test_tir_pass_virtual_thread.py | 45 -----
...c_sync.py => test_tir_transform_coproc_sync.py} | 15 +-
...py => test_tir_transform_inject_copy_intrin.py} | 21 ++-
... => test_tir_transform_inject_double_buffer.py} | 14 +-
...=> test_tir_transform_inject_virtual_thread.py} | 26 ++-
...est_tir_transform_instrument_bound_checkers.py} | 39 +---
...pe.py => test_tir_transform_lift_attr_scope.py} | 10 +-
...ion.py => test_tir_transform_loop_partition.py} | 165 ++++++++++-------
...no_op.py => test_tir_transform_remove_no_op.py} | 14 +-
...=> test_tir_transform_rewrite_unsafe_select.py} | 16 +-
..._simplify.py => test_tir_transform_simplify.py} | 8 +-
...te.py => test_tir_transform_storage_rewrite.py} | 91 +++++++---
...unroll.py => test_tir_transform_unroll_loop.py} | 18 +-
...ectorize.py => test_tir_transform_vectorize.py} | 40 ++++-
37 files changed, 1026 insertions(+), 458 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 6af9958..5c4990a 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -53,7 +53,6 @@ struct ExprDeepEqual {
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
-
/*!
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h
index e228ce3..f3d447e 100644
--- a/include/tvm/tir/ir_pass.h
+++ b/include/tvm/tir/ir_pass.h
@@ -203,59 +203,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
bool VerifyCompactBuffer(Stmt stmt);
/*!
- * \brief Remove No Op from the Stmt.
- * \param stmt The stmt to be trasnformed
- * \return Transformed stmt.
- */
-Stmt RemoveNoOp(Stmt stmt);
-
-/*!
- * \brief unroll the constant loop marked by unroll.
- * This pass also automatically attach pragma unroll tag to loops which meets
the standard.
- *
- * \param stmt The statment to be unrolled.
- * \param auto_max_step The maximum step before stop attach automatic unroll
- * \param auto_max_depth The maximum depth before stop attach automatic unroll
- * \param auto_max_extent The maximum extent of the loop we can unroll,
- * this is an legacy option that do not take the loop
total steps into account.
- * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll
annotation to codegen.
- * \return Transformed stmt.
- */
-Stmt UnrollLoop(Stmt stmt,
- int auto_max_step,
- int auto_max_depth,
- int auto_max_extent,
- bool explicit_unroll);
-
-/*!
- * \brief vectorize the constant loops
- * \param stmt The statement to be vectorized.
- * \return Transformed stmt.
- */
-Stmt VectorizeLoop(Stmt stmt);
-
-/*!
- * \brief convert vectorized loops into serialized loops
- * \param stmt The statement to skip vectorization on.
- * \return Transformed stmt.
- */
-Stmt SkipVectorize(Stmt stmt);
-
-/*!
-* \brief instruments bound checkers.
-* \param stmt The statement to be instrumented.
-* \return Instrumented stmt.
-*/
-Stmt InstrumentBoundCheckers(Stmt stmt);
-
-/*!
- * \brief Inject virtual thread loops into stmt.
- * \param stmt The statement to be transformed.
- * \return Transformed stmt.
- */
-Stmt InjectVirtualThread(Stmt stmt);
-
-/*!
* \brief Inject prefetch instructions into stmt.
* \param stmt The statement to be transformed.
* \return Transformed stmt.
@@ -263,84 +210,6 @@ Stmt InjectVirtualThread(Stmt stmt);
Stmt InjectPrefetch(Stmt stmt);
/*!
- * \brief Inject double buffer into stmt.
- * \param stmt The statement to be transformed.
- * \param split_loop Loop splitting factor.
- * \return Transformed stmt.
- */
-Stmt InjectDoubleBuffer(Stmt stmt, int split_loop);
-
-/*!
- * \brief Inject copy intrinsics with optional pad.
- *
- * \param stmt The statement to be transformed.
- * \param pragma_key The pragma key for hint of copy.
- * \param fintrin The function with signature
- *
- * Stmt fintrin(Buffer src,
- * Buffer dst,
- * Array<Expr> pad_before,
- * Array<Expr> pad_after,
- * Expr pad_value)
- * \return Transformed stmt.
- */
-Stmt InjectCopyIntrin(Stmt stmt,
- const std::string& pragma_key,
- const runtime::PackedFunc& fintrin);
-
-/*!
- * \brief Rewrite storage allocation pattern.
- * Moves the allocation to outer most possible scope.
- * Trying to share space between allocations to make
- * a static allocation plan when possible.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt StorageRewrite(Stmt stmt);
-
-/*!
- * \brief partition loops in the stmt
- * \param stmt The stmt to do loop partition
- * \param split_const_loop flag to enable partition for const loop
- * \return Transformed stmt.
- */
-Stmt LoopPartition(Stmt stmt, bool split_const_loop);
-
-/*!
- * \brief Detect and insert sync points to co-processor.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt CoProcSync(Stmt stmt);
-
-/*!
- * \brief Lift common attrs with attr_key to outer scope.
- *
- * \param stmt The stmt to be transformed
- * \param attr_key The attribute key to be checked.
- * \return Transformed stmt.
- */
-Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
-
-/*!
- * \brief Detect and rewrite unsafe select that contains memory access.
- * \param stmt The statement to be rewritten.
- * \return Transformed stmt.
- */
-Stmt RewriteUnsafeSelect(Stmt stmt);
-
-/*!
- * \brief Lower attached storage access information.
- * Do this pass after all storage access analysis finish.
- *
- * \param stmt The stmt to be transformed
- * \return Transformed stmt.
- */
-Stmt LowerStorageAccessInfo(Stmt stmt);
-
-/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
@@ -357,15 +226,6 @@ Stmt DecorateDeviceScope(Stmt stmt);
Stmt HoistIfThenElse(Stmt stmt);
/*!
- * \brief Narrow down PrimExpr datatype in stmt to target_bits.
- * \note Run this pass after StorageFlatten.
- * \param stmt The stmt to do datatype rewrite
- * \param target_bits the bit of target datatype
- * \return Transformed stmt.
- */
-Stmt NarrowDataType(Stmt stmt, int target_bits);
-
-/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 23c1955..e593e1b 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -59,6 +59,124 @@ TVM_DLL Pass CreatePrimFuncPass(const
runtime::TypedPackedFunc<
const tvm::Array<runtime::String>& required);
/*!
+ * \brief Inject copy intrinsics with optional pad.
+ *
+ * \param pragma_key The pragma key for hint of copy.
+ * \param fintrin The function with signature
+ *
+ * Stmt fintrin(Buffer src,
+ * Buffer dst,
+ * Array<Expr> pad_before,
+ * Array<Expr> pad_after,
+ * Expr pad_value)
+ * \return The pass.
+ */
+TVM_DLL Pass InjectCopyIntrin(std::string pragma_key,
+ runtime::PackedFunc fintrin);
+
+/*!
+ * \brief Detect and insert sync points to co-processor.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CoProcSync();
+
+/*!
+ * \brief Lift common attrs with attr_key to outer scope.
+ *
+ * \param attr_key The attribute key to be checked.
+ * \return The pass.
+ */
+TVM_DLL Pass LiftAttrScope(std::string attr_key);
+
+/*!
+ * \brief partition loops in the stmt.
+ *
+ * \param split_const_loop flag to enable partition for const loop
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LoopPartition(bool split_const_loop);
+
+/*!
+ * \brief Lower vectorization loops.
+ *
+ * \param enable_vectorize Whether vectorization is enabled.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true);
+
+/*!
+ * \brief Inject virtual thread loops.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InjectVirtualThread();
+
+/*!
+ * \brief Inject double buffer statements.
+ *
+ * \param split_loop_factor Loop splitting factor.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectDoubleBuffer(int split_loop_factor);
+
+/*!
+ * \brief Rewrite storage allocation pattern.
+ * Moves the allocation to outer most possible scope.
+ * Trying to share space between allocations to make
+ * a static allocation plan when possible.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass StorageRewrite();
+
+/*!
+ * \brief unroll the constant loop marked by unroll.
+ * This pass also automatically attaches the pragma unroll tag to loops which meet
the standard.
+ *
+ * \param auto_max_step The maximum step before we stop attaching automatic unroll
+ * \param auto_max_depth The maximum depth before we stop attaching automatic unroll
+ * \param auto_max_extent The maximum extent of the loop we can unroll,
+ * this is a legacy option that does not take the loop total steps into
account.
+ * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll
annotation to codegen.
+ * \return The pass.
+ */
+TVM_DLL Pass UnrollLoop(int auto_max_step,
+ int auto_max_depth,
+ int auto_max_extent,
+ bool explicit_unroll);
+
+/*!
+ * \brief Remove No Op from the Stmt.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveNoOp();
+
+/*!
+ * \brief Detect and rewrite unsafe select that contains memory access.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RewriteUnsafeSelect();
+
+/*!
+* \brief Run arithmetic simplifications on the statements and expressions.
+*
+* \return The pass.
+*/
+TVM_DLL Pass Simplify();
+
+/*!
+* \brief Instruments bound checkers.
+*
+* \return The pass.
+*/
+TVM_DLL Pass InstrumentBoundCheckers();
+
+/*!
* \brief Transform the high-level PrimFunc to a low-level version
* that can be used as an API function.
*
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index a429d07..18a8a47 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -179,6 +179,7 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit)
+
for f in lower_phase2:
stmt = f(stmt)
@@ -187,11 +188,14 @@ def lower(sch,
stmt = ir_pass.RemoveNoOp(stmt)
if not cfg.disable_select_rewriting:
stmt = ir_pass.RewriteUnsafeSelect(stmt)
+
for f in lower_phase3:
stmt = f(stmt)
+
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
stmt = ir_pass.InstrumentBoundCheckers(stmt)
+
if simple_mode:
return stmt
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 9f64a93..f83bb11 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -60,6 +60,203 @@ def Filter(fcond):
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
+def InjectCopyIntrin(pragma_key, fintrin):
+ """Inject copy intrinsics with optional pad.
+
+ Parameters
+ ----------
+ pragma_key : str
+ The pragma key for hint of copy.
+
+ fintrin : function
+ The function with signature copyintrin(src, dst, pad_before,
pad_after, pad_value)
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectCopyIntrin(pragma_key, fintrin)
+
+
+def CoProcSync():
+ """Detect and insert sync points to co-processor.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.CoProcSync()
+
+
+def LiftAttrScope(attr_key):
+ """Lift common attrs with attr_key to outer scope.
+
+ Parameters
+ ----------
+ attr_key : str
+ The attribute key to be checked.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LiftAttrScope(attr_key)
+
+
+def LoopPartition(split_const_loop):
+ """Partition loops in the stmt.
+
+ Parameters
+ ----------
+ split_const_loop : bool
+ Flag to enable partition for const loop.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LoopPartition(split_const_loop)
+
+
+def VectorizeLoop(enable_vectorize=True):
+ """Lower vectorization loops.
+
+ Parameters
+ ----------
+ enable_vectorize : bool
+ Whether vectorization is enabled.
+ Will lower to scalar loop when it is turned off.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VectorizeLoop(enable_vectorize)
+
+
+def InjectVirtualThread():
+ """Inject virtual thread loops.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectVirtualThread()
+
+
+def InjectDoubleBuffer(split_loop_factor):
+ """Inject double buffer statements.
+
+ Parameters
+ ----------
+ split_loop_factor : int
+ Loop splitting factor.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectDoubleBuffer(split_loop_factor)
+
+
+def StorageRewrite():
+ """Rewrite storage allocation pattern.
+
+ Moves the allocation to outer most possible scope.
+ Trying to share space between allocations to make
+ a static allocation plan when possible.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.StorageRewrite()
+
+
+def UnrollLoop(auto_max_step,
+ auto_max_depth,
+ auto_max_extent,
+ explicit_unroll):
+ """Unroll the constant loop marked by unroll.
+
+ This pass also automatically attaches the pragma unroll tag to loops which meet
the standard.
+
+ Parameters
+ ----------
+ auto_max_step : int
+ The maximum step before we stop attaching automatic unroll
+
+ auto_max_depth : int
+ The maximum depth before we stop attaching automatic unroll
+
+ auto_max_extent : int
+ The maximum extent of the loop we can unroll.
+ This is a legacy option that does not take the loop total steps into
account.
+
+ explicit_unroll : bool
+ Whether explicitly unroll the loop, or leave unroll annotation to
codegen.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.UnrollLoop(
+ auto_max_step, auto_max_depth, auto_max_extent, explicit_unroll)
+
+
+def RemoveNoOp():
+ """Remove No Op from the Stmt.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RemoveNoOp()
+
+
+def RewriteUnsafeSelect():
+ """Detect and rewrite unsafe select that contains memory access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RewriteUnsafeSelect()
+
+
+def Simplify():
+ """Run arithmetic simplifications on the statements and expressions.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.Simplify()
+
+
+def InstrumentBoundCheckers():
+ """Instruments bound checkers.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InstrumentBoundCheckers()
+
+
def LowerCustomDatatypes():
"""Lower custom datatypes.
diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h
index adb4f30..f842780 100644
--- a/src/arith/compute_expr.h
+++ b/src/arith/compute_expr.h
@@ -25,6 +25,7 @@
#define TVM_ARITH_COMPUTE_EXPR_H_
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <limits>
#include <algorithm>
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index f576c84..e38179e 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -109,64 +109,6 @@ void GetBinds(const Array<te::Tensor>& args,
}
}
-/*!
-* \brief Build a Stmt given a schedule, args and binds. This function runs the
IR passes.
-* \param sch The schedule to build.
-* \param args The arguments for the schedule.
-* \param binds Buffer assignments.
-* \param loop_partition True if the LoopPartition pass should be included.
-* \param out_arg_list Returns the arguments for the Stmt.
-* \param config The build configuration.
-* \return The built Stmt.
-*/
-tir::Stmt BuildStmt(te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- bool loop_partition,
- Array<ObjectRef> *out_arg_list,
- const BuildConfig& config) {
- sch = sch.normalize();
-
- // Phase 0
- auto bounds = te::InferBound(sch);
- auto stmt = te::ScheduleOps(sch, bounds, false);
- stmt = tir::InjectPrefetch(stmt);
-
- bool compact = tir::VerifyCompactBuffer(stmt);
- Map<te::Tensor, tir::Buffer> out_binds;
- GetBinds(args, compact, binds, &out_binds, out_arg_list, config);
-
- // Phase 1
- stmt = tir::StorageFlatten(stmt, out_binds, 64,
- config->instrument_bound_checkers);
- stmt = tir::CanonicalSimplify(stmt);
- if (loop_partition) {
- stmt = tir::LoopPartition(stmt, config->partition_const_loop);
- }
- if (config->disable_vectorize) {
- stmt = tir::SkipVectorize(stmt);
- } else {
- stmt = tir::VectorizeLoop(stmt);
- }
- stmt = tir::InjectVirtualThread(stmt);
- stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
- stmt = tir::StorageRewrite(stmt);
- stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step,
config->auto_unroll_max_depth,
- config->auto_unroll_max_extent, config->unroll_explicit);
-
- // Phase 2
- stmt = tir::Simplify(stmt);
- stmt = tir::RemoveNoOp(stmt);
-
- if (!(config->disable_select_rewriting))
- stmt = tir::RewriteUnsafeSelect(stmt);
-
- if (config->instrument_bound_checkers)
- stmt = tir::InstrumentBoundCheckers(stmt);
-
- return stmt;
-}
-
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext
ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
@@ -176,7 +118,7 @@ transform::Pass BindTarget(Target target) {
template<typename FCond>
-transform::Pass FilterBy(FCond fcond) {
+transform::Pass Filter(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext
ctx) {
if (fcond(f)) {
return f;
@@ -184,18 +126,14 @@ transform::Pass FilterBy(FCond fcond) {
return tir::PrimFunc(nullptr);
}
};
- return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
+ return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}
-IRModule lower(te::Schedule sch,
- const Array<te::Tensor>& args,
- const std::string& name,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- const BuildConfig& config) {
- Array<ObjectRef> out_arg_list;
- auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
-
+IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
+ tir::Stmt stmt,
+ const std::string& name,
+ const BuildConfig& config) {
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
@@ -216,10 +154,64 @@ IRModule lower(te::Schedule sch,
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.noalias", Integer(1));
}
+
return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
}
+IRModule lower(te::Schedule sch,
+ const Array<te::Tensor>& args,
+ const std::string& name,
+ const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+ const BuildConfig& config) {
+ Array<ObjectRef> out_arg_list;
+
+ sch = sch.normalize();
+
+ // Phase 0
+ auto bounds = te::InferBound(sch);
+ auto stmt = te::ScheduleOps(sch, bounds, false);
+ stmt = tir::InjectPrefetch(stmt);
+
+ bool compact = tir::VerifyCompactBuffer(stmt);
+ Map<te::Tensor, tir::Buffer> out_binds;
+ GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
+
+ // Phase 1
+ stmt = tir::StorageFlatten(stmt, out_binds, 64,
+ config->instrument_bound_checkers);
+
+ // convert to IRModule.
+ auto mod = BuildIRModule(out_arg_list, stmt, name, config);
+ auto pass_list = Array<tvm::transform::Pass>();
+
+ pass_list.push_back(tir::transform::Simplify());
+
pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
+
pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
+ pass_list.push_back(tir::transform::InjectVirtualThread());
+
pass_list.push_back(tir::transform::InjectDoubleBuffer(config->double_buffer_split_loop));
+ pass_list.push_back(tir::transform::StorageRewrite());
+ pass_list.push_back(
+ tir::transform::UnrollLoop(config->auto_unroll_max_step,
+ config->auto_unroll_max_depth,
+ config->auto_unroll_max_extent,
+ config->unroll_explicit));
+ // Phase 2
+ pass_list.push_back(tir::transform::Simplify());
+ pass_list.push_back(tir::transform::RemoveNoOp());
+ if (!(config->disable_select_rewriting)) {
+ pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+ }
+ if (config->instrument_bound_checkers) {
+ pass_list.push_back(tir::transform::InstrumentBoundCheckers());
+ }
+ // run
+ auto optimize = transform::Sequential(pass_list);
+ mod = optimize(std::move(mod));
+ return mod;
+}
+
+
std::pair<IRModule, IRModule>
split_dev_host_funcs(IRModule mod_mixed,
const Target& target,
@@ -242,7 +234,7 @@ split_dev_host_funcs(IRModule mod_mixed,
mod_mixed = opt_mixed(std::move(mod_mixed));
auto host_pass_list = {
- FilterBy([](const tir::PrimFunc& f) {
+ Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch;
@@ -258,7 +250,7 @@ split_dev_host_funcs(IRModule mod_mixed,
// device pipeline
auto device_pass_list = {
- FilterBy([](const tir::PrimFunc& f) {
+ Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(
tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch;
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index 3083b68..65981b9 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -114,27 +114,12 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
-REGISTER_PASS(RewriteUnsafeSelect);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
-REGISTER_PASS(VectorizeLoop);
-REGISTER_PASS(SkipVectorize);
-REGISTER_PASS(UnrollLoop);
-REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(StorageRewrite);
-REGISTER_PASS(CoProcSync);
-REGISTER_PASS(LowerStorageAccessInfo);
-REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
-REGISTER_PASS(InjectDoubleBuffer);
-REGISTER_PASS(LoopPartition);
-REGISTER_PASS(RemoveNoOp);
-REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
-REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
-REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/bound_checker.cc b/src/tir/transforms/bound_checker.cc
similarity index 88%
rename from src/tir/pass/bound_checker.cc
rename to src/tir/transforms/bound_checker.cc
index ee24d0f..f770bc7 100644
--- a/src/tir/pass/bound_checker.cc
+++ b/src/tir/transforms/bound_checker.cc
@@ -22,8 +22,11 @@
*/
// Instrument checkers for out of the bounds access.
+#include <tvm/runtime/registry.h>
+#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <vector>
#include <unordered_map>
@@ -173,8 +176,8 @@ class BoundChecker : public StmtExprMutator {
}
// Try to simplify index and bound.
- index = tir::Simplify(index);
- upper_bound = tir::Simplify(upper_bound);
+ index = analyzer_.Simplify(index);
+ upper_bound = analyzer_.Simplify(upper_bound);
// Cast to the same type - signed, to be able to check lower bound.
index = CastNode::make(DataType::Int(64), index);
@@ -201,6 +204,8 @@ class BoundChecker : public StmtExprMutator {
const char *const error_message_ = "OUT OF THE BOUNDS";
// Hashtable which maps buffer_var to shape.
std::unordered_map<const VarNode *, PrimExpr> mem_to_shape_;
+ // internal analyzer
+ arith::Analyzer analyzer_;
};
Stmt InstrumentBoundCheckers(Stmt stmt) {
@@ -209,5 +214,29 @@ Stmt InstrumentBoundCheckers(Stmt stmt) {
bound_collector(stmt);
return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt));
}
+
+
+TVM_REGISTER_GLOBAL("ir_pass.InstrumentBoundCheckers")
+.set_body_typed(InstrumentBoundCheckers);
+
+namespace transform {
+
+Pass InstrumentBoundCheckers() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ BoundCollector bound_collector;
+ // At first walk recursively and collect bound attributes.
+ bound_collector(n->body);
+ n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers")
+.set_body_typed(InstrumentBoundCheckers);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc
similarity index 97%
rename from src/tir/pass/coproc_sync.cc
rename to src/tir/transforms/coproc_sync.cc
index 38b7798..fc20285 100644
--- a/src/tir/pass/coproc_sync.cc
+++ b/src/tir/transforms/coproc_sync.cc
@@ -20,13 +20,14 @@
/*!
* \file coproc_sync.cc
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
#include <unordered_set>
-#include "ir_util.h"
-#include "storage_access.h"
+#include "../pass/ir_util.h"
+#include "../pass/storage_access.h"
namespace tvm {
namespace tir {
@@ -677,5 +678,24 @@ Stmt CoProcSync(Stmt stmt) {
return CoProcSyncInserter().Insert(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.CoProcSync")
+.set_body_typed(CoProcSync);
+
+namespace transform {
+
+Pass CoProcSync() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = CoProcSyncInserter().Insert(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.CoProcSync")
+.set_body_typed(CoProcSync);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/inject_copy_intrin.cc
b/src/tir/transforms/inject_copy_intrin.cc
similarity index 91%
rename from src/tir/pass/inject_copy_intrin.cc
rename to src/tir/transforms/inject_copy_intrin.cc
index 4805caf..5e40eb2 100644
--- a/src/tir/pass/inject_copy_intrin.cc
+++ b/src/tir/transforms/inject_copy_intrin.cc
@@ -21,10 +21,11 @@
* \brief Replace certain copy with copy intrinsics.
* \file copy_intrin_rewrite.cc
*/
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
#include "../../arith/pattern_match.h"
namespace tvm {
@@ -196,5 +197,26 @@ Stmt InjectCopyIntrin(Stmt stmt,
return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.InjectCopyIntrin")
+.set_body_typed(InjectCopyIntrin);
+
+namespace transform {
+
+Pass InjectCopyIntrin(std::string pragma_key,
+ PackedFunc flower_copy_fromto) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = CopyIntrinInjector(
+ pragma_key, flower_copy_fromto)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin")
+.set_body_typed(InjectCopyIntrin);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/inject_double_buffer.cc
b/src/tir/transforms/inject_double_buffer.cc
similarity index 93%
rename from src/tir/pass/inject_double_buffer.cc
rename to src/tir/transforms/inject_double_buffer.cc
index b9aa5a9..e9422fa 100644
--- a/src/tir/pass/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -21,10 +21,12 @@
* \brief Inject double buffering optimization for data fetch.
* \file inject_double_buffer.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
namespace tvm {
@@ -273,5 +275,26 @@ class DoubleBufferInjector : public StmtExprMutator {
Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) {
return DoubleBufferInjector(split_loop).Inject(stmt);
}
+
+TVM_REGISTER_GLOBAL("ir_pass.InjectDoubleBuffer")
+.set_body_typed(InjectDoubleBuffer);
+
+
+namespace transform {
+
+Pass InjectDoubleBuffer(int split_loop) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = DoubleBufferInjector(split_loop).Inject(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer")
+.set_body_typed(InjectDoubleBuffer);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/inject_virtual_thread.cc
b/src/tir/transforms/inject_virtual_thread.cc
similarity index 96%
rename from src/tir/pass/inject_virtual_thread.cc
rename to src/tir/transforms/inject_virtual_thread.cc
index e9c403c..c70962d 100644
--- a/src/tir/pass/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -20,8 +20,10 @@
/*!
* \file inject_virtual_thread.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "../../arith/compute_expr.h"
@@ -500,5 +502,24 @@ Stmt InjectVirtualThread(Stmt stmt) {
return ConvertSSA(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.InjectVirtualThread")
+.set_body_typed(InjectVirtualThread);
+
+namespace transform {
+
+Pass InjectVirtualThread() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = ConvertSSA(VirtualThreadInjector()(std::move(n->body)));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread")
+.set_body_typed(InjectVirtualThread);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/lift_attr_scope.cc
b/src/tir/transforms/lift_attr_scope.cc
similarity index 90%
rename from src/tir/pass/lift_attr_scope.cc
rename to src/tir/transforms/lift_attr_scope.cc
index 9aa037f..a1d9223 100644
--- a/src/tir/pass/lift_attr_scope.cc
+++ b/src/tir/transforms/lift_attr_scope.cc
@@ -23,9 +23,10 @@
* the body contains the same scope.
* \file lift_attr_scope.cc
*/
-#include <tvm/tir/ir_pass.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
namespace tvm {
namespace tir {
@@ -191,5 +192,24 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) {
return AttrScopeLifter(attr_key).Lift(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.LiftAttrScope")
+.set_body_typed(LiftAttrScope);
+
+namespace transform {
+
+// Create the PrimFunc-level LiftAttrScope pass.
+// attr_key is captured by value in the pass function and handed to
+// AttrScopeLifter, whose Lift() result replaces the function body.
+// The IRModule and PassContext arguments are accepted but not used.
+Pass LiftAttrScope(std::string attr_key) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body));
+ return f;
+ };
+ // The pass is created under the name "tir.LiftAttrScope".
+ return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope")
+.set_body_typed(LiftAttrScope);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/loop_partition.cc
b/src/tir/transforms/loop_partition.cc
similarity index 96%
rename from src/tir/pass/loop_partition.cc
rename to src/tir/transforms/loop_partition.cc
index e9157e7..dbed5f2 100644
--- a/src/tir/pass/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -20,9 +20,11 @@
/*!
* \file loop_partition.cc
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/analyzer.h>
#include <unordered_map>
#include <unordered_set>
@@ -500,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
Stmt pre_stmt;
bool pre_stmt_recurse = true;
if (middle_interval_i->HasLowerBound()) {
- body_begin = tir::Simplify(middle_interval.min());
+ body_begin = analyzer_.Simplify(middle_interval.min());
if (!analyzer_.CanProve(body_begin == min)) {
PrimExpr cond = (body_begin - min >= 0);
if (!analyzer_.CanProve(cond)) {
@@ -525,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node,
Stmt post_stmt;
bool post_stmt_recurse = true;
if (middle_interval_i->HasUpperBound()) {
- post_doubt_begin = tir::Simplify(middle_interval.max() + 1);
+ post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1);
if (!analyzer_.CanProve(middle_interval.max() == max)) {
// require the extent to be non-negative
PrimExpr cond = (max - post_doubt_begin + 1 >= 0);
@@ -588,7 +590,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node,
PrimExpr extent, Stmt b
return Substitute(body, {{Var{for_node->loop_var},
make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0),
extent,
- for_node->for_type, for_node->device_api, body);
+ for_node->for_type, for_node->device_api, body);
}
}
@@ -610,5 +612,25 @@ Stmt LoopPartition(Stmt stmt, bool split_const_loop) {
return stmt;
}
+
+TVM_REGISTER_GLOBAL("ir_pass.LoopPartition")
+.set_body_typed(LoopPartition);
+
+namespace transform {
+
+// Create the PrimFunc-level LoopPartition pass.
+// split_const_loop is captured by value and forwarded, together with the
+// function body, to the two-argument Stmt-level LoopPartition(Stmt, bool).
+// The IRModule and PassContext arguments are accepted but not used.
+Pass LoopPartition(bool split_const_loop) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ // Two-argument call: this is the Stmt overload, not this Pass factory.
+ n->body = LoopPartition(std::move(n->body), split_const_loop);
+ return f;
+ };
+ // The pass is created under the name "tir.LoopPartition".
+ return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LoopPartition")
+.set_body_typed(LoopPartition);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/lower_device_storage_access_info.cc
b/src/tir/transforms/lower_device_storage_access_info.cc
index e7f81ed..9fa7230 100644
--- a/src/tir/transforms/lower_device_storage_access_info.cc
+++ b/src/tir/transforms/lower_device_storage_access_info.cc
@@ -143,6 +143,8 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccessInfo")
+.set_body_typed(LowerStorageAccessInfo);
namespace transform {
diff --git a/src/tir/transforms/narrow_datatype.cc
b/src/tir/transforms/narrow_datatype.cc
index 1f9d976..4aeaafd 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -395,6 +395,10 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
return DataTypeRewriter(target_bits)(stmt);
}
+TVM_REGISTER_GLOBAL("ir_pass.NarrowDataType")
+.set_body_typed(NarrowDataType);
+
+
namespace transform {
Pass NarrowDataType(int target_bits) {
diff --git a/src/tir/pass/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
similarity index 89%
rename from src/tir/pass/remove_no_op.cc
rename to src/tir/transforms/remove_no_op.cc
index 181a8c4..44c974f 100644
--- a/src/tir/pass/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -21,8 +21,11 @@
* \file remove_no_op.cc
* \brief Remove no op from the stmt
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_map>
@@ -147,5 +150,25 @@ class NoOpRemover : public StmtMutator {
Stmt RemoveNoOp(Stmt stmt) {
return NoOpRemover()(std::move(stmt));
}
+
+TVM_REGISTER_GLOBAL("ir_pass.RemoveNoOp")
+.set_body_typed(RemoveNoOp);
+
+namespace transform {
+
+// Create the PrimFunc-level RemoveNoOp pass.
+// The pass function replaces each PrimFunc body with the result of
+// running NoOpRemover over it.
+// The IRModule and PassContext arguments are accepted but not used.
+Pass RemoveNoOp() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ n->body = NoOpRemover()(std::move(n->body));
+ return f;
+ };
+ // The pass is created under the name "tir.RemoveNoOp".
+ return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp")
+.set_body_typed(RemoveNoOp);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/rewrite_unsafe_select.cc
b/src/tir/transforms/rewrite_unsafe_select.cc
similarity index 89%
rename from src/tir/pass/rewrite_unsafe_select.cc
rename to src/tir/transforms/rewrite_unsafe_select.cc
index 5016492..386b4cc 100644
--- a/src/tir/pass/rewrite_unsafe_select.cc
+++ b/src/tir/transforms/rewrite_unsafe_select.cc
@@ -21,9 +21,10 @@
* \file unsafe_select_rewrite.cc
* \brief Rewrite unsafe select expression.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
@@ -132,5 +133,24 @@ Stmt RewriteUnsafeSelect(Stmt stmt) {
return UnsafeSelectRewriter()(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.RewriteUnsafeSelect")
+.set_body_typed(RewriteUnsafeSelect);
+
+namespace transform {
+
+// Create the PrimFunc-level RewriteUnsafeSelect pass.
+// The pass function replaces each PrimFunc body with the result of
+// running UnsafeSelectRewriter over it.
+// The IRModule and PassContext arguments are accepted but not used.
+Pass RewriteUnsafeSelect() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ n->body = UnsafeSelectRewriter()(std::move(n->body));
+ return f;
+ };
+ // The pass is created under the name "tir.RewriteUnsafeSelect".
+ return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect")
+.set_body_typed(RewriteUnsafeSelect);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/arith/stmt_simplify.cc b/src/tir/transforms/simplify.cc
similarity index 86%
rename from src/arith/stmt_simplify.cc
rename to src/tir/transforms/simplify.cc
index 6c3dd02..ecfa25e 100644
--- a/src/arith/stmt_simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -18,17 +18,19 @@
*/
/*!
- * \file stmt_simplify.cc
+ * \file simplify.cc
* \brief Statement simplifier based on analyzer
*/
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/arith/analyzer.h>
-#include "ir_mutator_with_analyzer.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
namespace tvm {
namespace arith {
@@ -125,5 +127,23 @@ PrimExpr Simplify(PrimExpr expr, Map<Var, Range> vrange) {
Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
return CanonicalSimplify(std::move(stmt), vrange);
}
+
+namespace transform {
+
+// Create the PrimFunc-level Simplify pass.
+// For every PrimFunc the pass function constructs a new arith::Analyzer and
+// uses arith::StmtSimplifier bound to it to simplify the function body.
+// The IRModule and PassContext arguments are accepted but not used.
+Pass Simplify() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ // A fresh analyzer per invocation: nothing is shared between functions.
+ arith::Analyzer analyzer;
+ n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body));
+ return f;
+ };
+ // The pass is created under the name "tir.Simplify".
+ return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.Simplify")
+.set_body_typed(Simplify);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/storage_rewrite.cc
b/src/tir/transforms/storage_rewrite.cc
similarity index 98%
rename from src/tir/pass/storage_rewrite.cc
rename to src/tir/transforms/storage_rewrite.cc
index f3604b6..c13879c 100644
--- a/src/tir/pass/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -22,16 +22,18 @@
* \brief Memory access pattern analysis and optimization.
* Re-write data access to enable memory sharing when possible.
*/
+#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/target_info.h>
#include <map>
#include <unordered_set>
#include <unordered_map>
-#include "ir_util.h"
+#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
@@ -1039,5 +1041,26 @@ Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
return VectorAllocRewriter()(std::move(stmt));
}
+
+TVM_REGISTER_GLOBAL("ir_pass.StorageRewrite")
+.set_body_typed(StorageRewrite);
+
+namespace transform {
+
+// Create the PrimFunc-level StorageRewrite pass.
+// The pass function applies two rewrites to the function body in order:
+// StoragePlanRewriter().Rewrite(body, true), then VectorAllocRewriter.
+// The IRModule and PassContext arguments are accepted but not used.
+Pass StorageRewrite() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ // NOTE(review): meaning of the `true` flag is defined by
+ // StoragePlanRewriter::Rewrite, which is not shown here -- confirm.
+ n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
+ n->body = VectorAllocRewriter()(std::move(n->body));
+ return f;
+ };
+ // The pass is created under the name "tir.StorageRewrite".
+ return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite")
+.set_body_typed(StorageRewrite);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
similarity index 88%
rename from src/tir/pass/unroll_loop.cc
rename to src/tir/transforms/unroll_loop.cc
index 0167dbc..27c39d4 100644
--- a/src/tir/pass/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -22,8 +22,11 @@
* \file unroll_loop.cc
*/
// Unrolls the loop as in Halide pipeline.
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
#include <unordered_map>
@@ -201,13 +204,31 @@ Stmt UnrollLoop(Stmt stmt,
}
}
-Stmt UnrollLoopExplicitly(Stmt stmt) {
- const ForNode* op = stmt.as<ForNode>();
- if (!op) {
- LOG(FATAL) << "attempted to unroll a non-loop statement";
- }
- return LoopUnroller(0, 0, 0, false).Unroll(op);
+TVM_REGISTER_GLOBAL("ir_pass.UnrollLoop")
+.set_body_typed(UnrollLoop);
+
+namespace transform {
+
+// Create the PrimFunc-level UnrollLoop pass.
+// All four arguments are captured by value and forwarded unchanged,
+// together with the function body, to the Stmt-level UnrollLoop.
+// The IRModule and PassContext arguments are accepted but not used.
+Pass UnrollLoop(int auto_max_step,
+ int auto_max_depth,
+ int auto_max_extent,
+ bool explicit_unroll) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ // Move the body out of the writable node n, as the sibling passes do.
+ n->body = UnrollLoop(std::move(n->body),
+ auto_max_step,
+ auto_max_depth,
+ auto_max_extent,
+ explicit_unroll);
+ return f;
+ };
+ // The pass is created under the name "tir.UnrollLoop".
+ return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {});
}
+TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop")
+.set_body_typed(UnrollLoop);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/vectorize_loop.cc
b/src/tir/transforms/vectorize_loop.cc
similarity index 95%
rename from src/tir/pass/vectorize_loop.cc
rename to src/tir/transforms/vectorize_loop.cc
index b73587d..cc4361d 100644
--- a/src/tir/pass/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -21,9 +21,11 @@
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/ir_pass.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
#include <unordered_map>
@@ -539,8 +541,9 @@ class VectorizeSkipper : public StmtMutator {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->for_type == ForType::Vectorized) {
- return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial,
op->device_api,
- op->body);
+ return ForNode::make(op->loop_var, op->min, op->extent,
+ ForType::Serial, op->device_api,
+ op->body);
} else {
return stmt;
}
@@ -551,5 +554,32 @@ Stmt SkipVectorize(Stmt stmt) {
return VectorizeSkipper()(std::move(stmt));
}
+TVM_REGISTER_GLOBAL("ir_pass.VectorizeLoop")
+.set_body_typed(VectorizeLoop);
+
+TVM_REGISTER_GLOBAL("ir_pass.SkipVectorize")
+.set_body_typed(SkipVectorize);
+
+namespace transform {
+
+// Create the PrimFunc-level VectorizeLoop pass.
+// When enable_vectorize is true the function body is rewritten with
+// LoopVectorizer; otherwise it is rewritten with VectorizeSkipper, which
+// turns ForType::Vectorized loops into ForType::Serial loops.
+// The IRModule and PassContext arguments are accepted but not used.
+// TODO(tvm-team): Make enable_vectorize a target property.
+Pass VectorizeLoop(bool enable_vectorize) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ // Get a writable node for f before replacing its body.
+ auto* n = f.CopyOnWrite();
+ if (enable_vectorize) {
+ n->body = LoopVectorizer()(std::move(n->body));
+ } else {
+ n->body = VectorizeSkipper()(std::move(n->body));
+ }
+ return f;
+ };
+ // The pass is created under the name "tir.VectorizeLoop".
+ return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop")
+.set_body_typed(VectorizeLoop);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_tir_pass_virtual_thread.py
b/tests/python/unittest/test_tir_pass_virtual_thread.py
deleted file mode 100644
index 2d96696..0000000
--- a/tests/python/unittest/test_tir_pass_virtual_thread.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-from tvm import te
-
-def test_virtual_thread():
- m = te.var('m')
- A = te.placeholder((m, ), name='A')
- A1 = te.compute((m,), lambda i: A[i], name='A1')
- A2 = te.compute((m,), lambda i: A1[i] + 3, name='A2')
-
- s = te.create_schedule(A2.op)
- vx = te.thread_axis("vthread", name="vx")
- xo, xi = s[A2].split(A2.op.axis[0], nparts=2)
- s[A2].bind(xo, vx)
- xo, xi = s[A2].split(xi, 8)
- s[A1].compute_at(s[A2], xo)
-
- bounds = tvm.te.schedule.InferBound(s)
- assert isinstance(bounds, tvm.container.Map)
- stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
- print(stmt)
-
-if __name__ == "__main__":
- test_virtual_thread()
diff --git a/tests/python/unittest/test_tir_pass_coproc_sync.py
b/tests/python/unittest/test_tir_transform_coproc_sync.py
similarity index 91%
rename from tests/python/unittest/test_tir_pass_coproc_sync.py
rename to tests/python/unittest/test_tir_transform_coproc_sync.py
index b0e2050..f658349 100644
--- a/tests/python/unittest/test_tir_pass_coproc_sync.py
+++ b/tests/python/unittest/test_tir_transform_coproc_sync.py
@@ -37,7 +37,10 @@ def test_coproc_sync():
ib.scope_attr(cp, "coproc_scope", 1)
A[j] = A[j + k * 10] + 2
stmt = ib.get()
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
body = stmt.body.body.body
blist = tvm.tir.stmt_list(body)
assert(blist[1].value.name == "cop.coproc_read_barrier")
@@ -65,7 +68,10 @@ def test_coproc_sync2():
ib.scope_attr(cp, "coproc_scope", 2)
A[ty] = 1.0
stmt = ib.get()
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
def test_coproc_sync3():
def __check_list(tvm_array, py_list):
@@ -91,7 +97,10 @@ def test_coproc_sync3():
A[0] = 0.0
stmt = ib.get()
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body
+
slist = tvm.tir.stmt_list(stmt[0].body.body)
push_st = slist[2]
slist = tvm.tir.stmt_list(slist[-1])
diff --git a/tests/python/unittest/test_tir_pass_inject_copy_intrin.py
b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
similarity index 89%
rename from tests/python/unittest/test_tir_pass_inject_copy_intrin.py
rename to tests/python/unittest/test_tir_transform_inject_copy_intrin.py
index 8c34e34..7ec2e48 100644
--- a/tests/python/unittest/test_tir_pass_inject_copy_intrin.py
+++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
@@ -35,7 +35,10 @@ def test_copy2d():
assert src.strides[0] == l
assert tuple(src.shape) == (m, l)
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
def test_copy_pad():
m = te.var('m')
@@ -59,7 +62,10 @@ def test_copy_pad():
assert pad_after[1].value == 0
assert pad_value.value == 1.0
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
def test_single_point_test():
A = te.placeholder((1,), name='A')
@@ -78,7 +84,10 @@ def test_single_point_test():
assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
def assert_expr_equal(a, b):
assert tvm.tir.ir_pass.Simplify(a - b).value == 0
@@ -111,7 +120,11 @@ def test_copy_pad_split():
assert_expr_equal(pad_after[0], rpad_after)
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0)
- stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
+
+
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py
b/tests/python/unittest/test_tir_transform_inject_double_buffer.py
similarity index 91%
rename from tests/python/unittest/test_tir_pass_inject_double_buffer.py
rename to tests/python/unittest/test_tir_transform_inject_double_buffer.py
index 6b04db3..4c0573d 100644
--- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py
+++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py
@@ -36,13 +36,19 @@ def test_double_buffer():
C[j] = B[j] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert isinstance(stmt.body.body, tvm.tir.Allocate)
- assert stmt.body.body.extents[0].value == 2
mod = tvm.IRModule({
"db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)
})
+
+ opt = tvm.transform.Sequential(
+ [tvm.tir.transform.InjectDoubleBuffer(2),
+ tvm.tir.transform.Simplify()])
+ mod = opt(mod)
+ stmt = mod["db"].body
+
+ assert isinstance(stmt.body.body, tvm.tir.Allocate)
+ assert stmt.body.body.extents[0].value == 2
+
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
def count_sync(op):
diff --git a/tests/python/unittest/test_tir_pass_inject_vthread.py
b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
similarity index 83%
rename from tests/python/unittest/test_tir_pass_inject_vthread.py
rename to tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index 8fbd829..c0789c6 100644
--- a/tests/python/unittest/test_tir_pass_inject_vthread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -40,9 +40,14 @@ def test_vthread():
C[i * nthread + tx] = B[i] + 1
return ib.get()
- stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body
+
assert stmt.body.body.extents[0].value == 2
- stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("cthread"))
+
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+
assert len(stmt.body.body.extents) == 3
@@ -67,16 +72,20 @@ def test_vthread_extern():
A[tx] = tx + 1.0
B[ty] = ty + 1.0
ib.emit(tvm.tir.call_extern("int32", "Run",
- abuffer.access_ptr("r"),
- bbuffer.access_ptr("r"),
- cbuffer.access_ptr("rw")))
+ abuffer.access_ptr("r"),
+ bbuffer.access_ptr("r"),
+ cbuffer.access_ptr("rw")))
return ib.get()
- stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body
+
assert stmt.body.body.extents[0].value == 2
assert stmt.body.body.body.body.body.body.extents[0].value == 2
assert len(stmt.body.body.body.body.body.body.extents) == 3
+
def test_vthread_if_then_else():
nthread = 2
tx = te.thread_axis("vthread")
@@ -92,7 +101,10 @@ def test_vthread_if_then_else():
with ib.if_scope(i == 0):
B[i] = A[i * nthread + tx] + 2
stmt = ib.get()
- stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt)
+
+ stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([], stmt)))["main"].body
+
assert stmt.body.body.body[0].else_case != None
assert stmt.body.body.body[1].else_case == None
diff --git a/tests/python/unittest/test_tir_pass_bound_checkers.py
b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py
similarity index 94%
rename from tests/python/unittest/test_tir_pass_bound_checkers.py
rename to tests/python/unittest/test_tir_transform_instrument_bound_checkers.py
index d6c89b2..47c1f7b 100644
--- a/tests/python/unittest/test_tir_pass_bound_checkers.py
+++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py
@@ -18,32 +18,12 @@ import pytest
import tvm
from tvm import te
import numpy as np
+
def collect_visit(stmt, f):
ret = []
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
return ret
-def lower(sch, args):
- binds = {}
- arg_list = []
- for x in args:
- if isinstance(x, te.tensor.Tensor):
- buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
- assert x not in binds
- binds[x] = buf
- arg_list.append(buf)
- else:
- raise ValueError("args must be Tensor, Buffer or Var")
- sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.RemoveNoOp(stmt)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, True)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- return stmt
@pytest.mark.xfail
def test_out_of_bounds_llvm(index_a, index_b):
@@ -72,7 +52,6 @@ def test_in_bounds_llvm():
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [A, B, C], simple_mode=True)
- print (stmt)
fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
ctx = tvm.context(tgt, 0)
a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
@@ -93,7 +72,6 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
tgt = "llvm"
tgt_host = "llvm"
stmt = tvm.lower (s, [a, b, c], simple_mode=True)
- print (stmt)
f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec")
ctx = tvm.cpu(0)
n = nn
@@ -192,13 +170,11 @@ def test_in_bounds_const_loop_partition_ir():
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
- bounds = tvm.te.schedule.InferBound(s)
- stmt = lower (s, [A, B, T])
- # num_attributes = num_buffers * num_splits = 2 * 3
- # before instrumentation
- assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
- assert_bound_instrumentation(stmt, check_branch_stmt, 0)
- stmt = tvm.tir.ir_pass.InstrumentBoundCheckers(stmt)
+ with tvm.target.build_config(instrument_bound_checkers=True,
+ partition_const_loop=True):
+ mod = tvm.driver.lower(s, [A, B, T], name="main")
+
+ stmt = mod["main"].body
# after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2)
@@ -209,7 +185,8 @@ def test_in_bounds_const_loop_partition_ir():
def test_in_bounds_const_loop_partition_llvm():
- with tvm.target.build_config(instrument_bound_checkers=True,
partition_const_loop=True):
+ with tvm.target.build_config(instrument_bound_checkers=True,
+ partition_const_loop=True):
n = 21
A = te.placeholder((n, ), name='A')
B = te.placeholder((n, ), name='B')
diff --git a/tests/python/unittest/test_tir_pass_lift_attr_scope.py
b/tests/python/unittest/test_tir_transform_lift_attr_scope.py
similarity index 88%
rename from tests/python/unittest/test_tir_pass_lift_attr_scope.py
rename to tests/python/unittest/test_tir_transform_lift_attr_scope.py
index 0831565..f5f4030 100644
--- a/tests/python/unittest/test_tir_pass_lift_attr_scope.py
+++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py
@@ -35,7 +35,10 @@ def test_coproc_lift():
A[j] = A[j] + 3
A[j] = A[j] + 3
body = ib.get()
- body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body =
tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
+
assert body.body.body.node == cp
# only able to lift to the common pattern of the last two fors.
@@ -52,7 +55,10 @@ def test_coproc_lift():
A[i] = A[i] + 2
body = ib.get()
- body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope")
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body =
tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body
+
assert body.body.body.body[1].node == cp
assert len(body.body.body.body) == 2
diff --git a/tests/python/unittest/test_tir_pass_loop_partition.py
b/tests/python/unittest/test_tir_transform_loop_partition.py
similarity index 79%
rename from tests/python/unittest/test_tir_pass_loop_partition.py
rename to tests/python/unittest/test_tir_transform_loop_partition.py
index 1256d8b..6ca3f59 100644
--- a/tests/python/unittest/test_tir_pass_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -23,26 +23,6 @@ def collect_visit(stmt, f):
tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x)))
return ret
-def lower(sch, args):
- binds = {}
- arg_list = []
- for x in args:
- if isinstance(x, te.tensor.Tensor):
- buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
- assert x not in binds
- binds[x] = buf
- arg_list.append(buf)
- else:
- raise ValueError("args must be Tensor, Buffer or Var")
- sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- return stmt
def test_basic():
n = te.size_var('n')
@@ -55,10 +35,16 @@ def test_basic():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body.body[0]))
- assert('if' in str(stmt.body.body[1]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(
+ collect_visit(stmt.body.body[0], lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+ assert(any(
+ collect_visit(stmt.body.body[1], lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+
def test_const_loop():
n = 21
@@ -71,9 +57,12 @@ def test_const_loop():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body.body[0]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
def test_multi_loop():
ib = tvm.tir.ir_builder.create()
@@ -87,8 +76,11 @@ def test_multi_loop():
with ib.else_scope():
ib.emit(tvm.tir.Evaluate(n))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x,
tvm.tir.IfThenElse))))
def test_multi_if():
@@ -107,9 +99,14 @@ def test_multi_if():
with ib.else_scope():
ib.emit(tvm.tir.Evaluate(n))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body[0]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(
+ collect_visit(stmt.body[0], lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+
def test_thread_axis():
m = te.size_var('m')
@@ -126,9 +123,14 @@ def test_thread_axis():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- assert('if' not in str(stmt.body.body[0]))
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+ assert(not any(
+ collect_visit(stmt.body. body[0], lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+
def test_vectorize():
n = te.size_var('n')
@@ -147,11 +149,12 @@ def test_vectorize():
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].vectorize(x)
- stmt = lower(s, [A, B])
+ stmt = tvm.lower(s, [A, B], name="main")["main"].body
body = stmt.body.body.body.body
assert(x.var.name not in str(body.condition))
assert(any(collect_visit(body.then_case, lambda x: isinstance(x,
tvm.tir.Ramp))))
+
def test_condition():
ib = tvm.tir.ir_builder.create()
m = te.size_var('m')
@@ -161,10 +164,14 @@ def test_condition():
ib.emit(tvm.tir.Evaluate(
tvm.tir.Select(ib.likely(i*4+j<n), m, n)))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x,
tvm.tir.Select))))
+
def test_condition_EQ():
ib = tvm.tir.ir_builder.create()
m = te.size_var('m')
@@ -173,10 +180,14 @@ def test_condition_EQ():
ib.emit(tvm.tir.Evaluate(
tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n)))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt[0], lambda x: isinstance(x,
tvm.tir.Select))))
+
def test_thread_axis2():
n = tvm.runtime.convert(4096)
m = te.size_var('m')
@@ -190,7 +201,7 @@ def test_thread_axis2():
_, x = s[C].split(x, factor=m)
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
- stmt = lower(s, [A, B])
+ stmt = tvm.lower(s, [A, B], name="main")["main"].body
for_body = stmt.body.body.body.body[0]
assert('threadIdx' not in str(for_body.extent))
@@ -204,8 +215,12 @@ def test_everything_during_deduction():
# this guard will produce everything during deduction
ib.emit(tvm.tir.Evaluate(m))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt))
+ mod = tvm.tir.transform.LoopPartition(False)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
+
assert(isinstance(stmt.body.body, tvm.tir.IfThenElse))
def test_single_likely():
@@ -220,8 +235,11 @@ def test_single_likely():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
def test_multi_likely():
@@ -241,10 +259,14 @@ def test_multi_likely():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+
def test_oneD_pool():
m = te.size_var('m')
ib = tvm.tir.ir_builder.create()
@@ -268,10 +290,14 @@ def test_oneD_pool():
out[ow] = tvm.te.max(out[ow], data[ow + kw - 1])
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, data, out], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+
def test_cce_loop_1():
ib = tvm.tir.ir_builder.create()
dtype = 'float16'
@@ -289,8 +315,11 @@ def test_cce_loop_1():
with ib.if_scope(ib.likely(((i*160) + j) < 1600)):
A[(i+1)*m+j+1] = B[(i)*m+j+1] + B[(i+1)*m+j+1] + B[(i+2)*m+j+1]
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
def test_cce_loop_2():
@@ -308,8 +337,12 @@ def test_cce_loop_2():
ib.emit(tvm.tir.call_extern('float32', "cce_intrisic", head, tail))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
@@ -326,10 +359,14 @@ def test_cce_loop_3():
ib.emit(tvm.tir.call_extern('float16',"cce_intrisic",head1))
stmt = ib.get()
- stmt = tvm.tir.ir_pass.LoopPartition(stmt,True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
+
def test_conv_tiling():
HSTR = WSTR = 1
in_channel = 128
@@ -355,8 +392,11 @@ def test_conv_tiling():
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, True)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ mod = tvm.tir.transform.LoopPartition(True)(mod)
+ stmt = tvm.tir.transform.Simplify()(mod)["main"].body
+
assert(not any(collect_visit(stmt, lambda x: isinstance(x,
tvm.tir.IfThenElse))))
@@ -426,14 +466,15 @@ def test_simple_rfactor():
s.normalize()
bounds = tvm.te.schedule.InferBound(s)
-
stmt1 = tvm.te.schedule.ScheduleOps(s, bounds)
- stmt1 = tvm.tir.ir_pass.Simplify(stmt1)
- stmt2 = tvm.tir.ir_pass.LoopPartition(stmt1, True)
- stmt2 = tvm.tir.ir_pass.Simplify(stmt2)
+ mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1))
+ stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body
+
+ mod2 = tvm.tir.transform.LoopPartition(True)(mod1)
+ stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body
- #make sure loop partition actually did something
+ # make sure loop partition actually did something
assert not tvm.ir.structural_equal(stmt1.body, stmt2.body)
diff --git a/tests/python/unittest/test_tir_pass_remove_no_op.py
b/tests/python/unittest/test_tir_transform_remove_no_op.py
similarity index 81%
rename from tests/python/unittest/test_tir_pass_remove_no_op.py
rename to tests/python/unittest/test_tir_transform_remove_no_op.py
index c9ecfbe..c58b8b4 100644
--- a/tests/python/unittest/test_tir_pass_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -36,16 +36,24 @@ def test_remove_no_op():
k, 0, m, 0, 0,
tvm.tir.IfThenElse(
(i*m+j+k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)))))
- ret = tvm.tir.ir_pass.RemoveNoOp(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+ ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
+
assert(isinstance(ret, tvm.tir.Evaluate))
store = tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1)
stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])])
- assert(tvm.tir.ir_pass.RemoveNoOp(stmt2) == store)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2))
+ ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
+ assert(ret == store)
+
# remove zero extent loop
stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
- ret = tvm.tir.ir_pass.RemoveNoOp(stmt3)
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
+ ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
assert(isinstance(ret, tvm.tir.Evaluate))
diff --git a/tests/python/unittest/test_tir_pass_rewrite_unsafe_select.py
b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py
similarity index 70%
rename from tests/python/unittest/test_tir_pass_rewrite_unsafe_select.py
rename to tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py
index f1e411e..229c11b 100644
--- a/tests/python/unittest/test_tir_pass_rewrite_unsafe_select.py
+++ b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py
@@ -23,14 +23,22 @@ def test_rewrite_Select():
A = ib.allocate("float32", 100, name="A", scope="global")
i = te.var("i")
y = tvm.tir.Select(i > 1, A[i-1], 1.0)
- yy = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value
+
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y)))
+ yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
z = tvm.tir.Select(
tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
- zz = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z)))
+ zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
+
+ a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z)
- a = tvm.tir.Select(tvm.te.floordiv(i, 4) > 10, y, z)
- aa = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a)))
+ aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
assert isinstance(aa, tvm.tir.Select)
diff --git a/tests/python/unittest/test_arith_stmt_simplify.py
b/tests/python/unittest/test_tir_transform_simplify.py
similarity index 93%
rename from tests/python/unittest/test_arith_stmt_simplify.py
rename to tests/python/unittest/test_tir_transform_simplify.py
index 45f0833..bf53982 100644
--- a/tests/python/unittest/test_arith_stmt_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -27,7 +27,9 @@ def test_stmt_simplify():
A[i] = C[i]
body = tvm.tir.LetStmt(n, 10, ib.get())
- body = tvm.tir.ir_pass.CanonicalSimplify(body)
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C, n], body))
+ body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body, tvm.tir.Store)
@@ -44,7 +46,9 @@ def test_thread_extent_simplify():
with ib.if_scope(tx + ty < 12):
A[tx] = C[tx + ty]
body = tvm.tir.LetStmt(n, 10, ib.get())
- body = tvm.tir.ir_pass.CanonicalSimplify(body)
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C, n], body))
+ body = tvm.tir.transform.Simplify()(mod)["main"].body
assert isinstance(body.body.body.body, tvm.tir.Store)
diff --git a/tests/python/unittest/test_tir_pass_storage_rewrite.py
b/tests/python/unittest/test_tir_transform_storage_rewrite.py
similarity index 89%
rename from tests/python/unittest/test_tir_pass_storage_rewrite.py
rename to tests/python/unittest/test_tir_transform_storage_rewrite.py
index b36d86b..e4e1b31 100644
--- a/tests/python/unittest/test_tir_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -33,9 +33,12 @@ def test_storage_share():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
@@ -72,7 +75,10 @@ def test_alloc_seq():
A[j] = 1.3
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
@@ -129,7 +135,10 @@ def test_alloc_different_dtypes():
body = stmt_generater(dtype_list, length)
offset = offset_generater(dtype_list, length)
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
tvm.tir.ir_pass.PostOrderVisit(body, verify)
length = 1024
@@ -160,9 +169,12 @@ def test_inplace_rule():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
@@ -192,9 +204,12 @@ def test_storage_combine():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
@@ -226,9 +241,12 @@ def test_storage_share_gpu():
Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A')
Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
alloc_stats = {"global": 0, "shared": 0}
def verify(n):
@@ -248,7 +266,9 @@ def test_parallel_alloc():
A[j] = A[j] + 2
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
assert (isinstance(body.body.body, tvm.tir.Allocate))
ib = tvm.tir.ir_builder.create()
@@ -262,7 +282,9 @@ def test_parallel_alloc():
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
assert(isinstance(body.body.body.body.body, tvm.tir.Allocate))
@@ -289,9 +311,12 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits =
1024 * 1024 * 1024):
Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd},
64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
num_alloc = [0]
@@ -381,10 +406,13 @@ def test_inplace_rule3():
B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a,
B3: B2a, B4: B4a, B5: B5a, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.StorageRewrite(stmt)
+ stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a,
B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a,
B5a, Bb], stmt))
+ mod = tvm.tir.transform.Simplify()(mod)
+ mod = tvm.tir.transform.StorageRewrite()(mod)
+ stmt = mod["main"].body
+
# verify only have one allocations.
# verify inplace folding works
def verify(n):
@@ -411,7 +439,10 @@ def test_alloc_seq_type():
A2[j] = A[j]
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
@@ -440,7 +471,10 @@ def test_alloc_seq_type2():
C[j] = 1.2
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
+
num_alloc = [0]
def verify(n):
if isinstance(n, tvm.tir.Allocate):
@@ -469,7 +503,9 @@ def test_reuse_small_buffer():
E[j] = C[j]
body = ib.get()
- body = tvm.tir.ir_pass.StorageRewrite(body)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
+ body = tvm.tir.transform.StorageRewrite()(mod)["main"].body
num_alloc = [0]
@@ -519,14 +555,15 @@ def test_large_input():
if __name__ == "__main__":
+ test_storage_share()
test_alloc_seq()
test_alloc_different_dtypes()
test_inplace_rule()
- test_storage_share()
test_parallel_alloc()
test_storage_combine()
test_storage_share_gpu()
test_inplace_rule2()
+
test_exceed_mem()
test_inplace_rule3()
test_alloc_seq_type()
diff --git a/tests/python/unittest/test_tir_pass_unroll.py
b/tests/python/unittest/test_tir_transform_unroll_loop.py
similarity index 84%
rename from tests/python/unittest/test_tir_pass_unroll.py
rename to tests/python/unittest/test_tir_transform_unroll_loop.py
index 165edab..7854835 100644
--- a/tests/python/unittest/test_tir_pass_unroll.py
+++ b/tests/python/unittest/test_tir_transform_unroll_loop.py
@@ -46,7 +46,11 @@ def test_unroll_loop():
wrapped = ib.get()
wrapped = tvm.tir.SeqStmt([wrapped, stmt])
assert isinstance(ret, tvm.tir.For)
- ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped))
+ ret = tvm.tir.transform.UnrollLoop(0, 8, 0, False)(mod)["main"].body
+
+ # ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False)
assert isinstance(ret[0], tvm.tir.For)
assert ret[0].for_type == tvm.tir.For.Unrolled
assert isinstance(ret[1], tvm.tir.For)
@@ -65,7 +69,11 @@ def test_unroll_fake_loop():
Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()
- ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt))
+ ret = tvm.tir.transform.UnrollLoop(8, 0, 1, False)(mod)["main"].body
+
+ # ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True)
assert isinstance(ret[0], tvm.tir.Store)
def test_unroll_single_count_loops():
@@ -78,8 +86,10 @@ def test_unroll_single_count_loops():
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
# all parameters to UnrollLoop are default values except for
# auto_unroll_max_extent which has been set to 1 (default:0)
- after_unroll_stmt = tvm.tir.ir_pass.UnrollLoop(stmt, 0, 8, 1, True)
- assert after_unroll_stmt == stmt
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))
+ ret = tvm.tir.transform.UnrollLoop(0, 8, 1, True)(mod)["main"].body
+
+ assert ret == stmt
if __name__ == "__main__":
test_unroll_loop()
diff --git a/tests/python/unittest/test_tir_pass_vectorize.py
b/tests/python/unittest/test_tir_transform_vectorize.py
similarity index 82%
rename from tests/python/unittest/test_tir_pass_vectorize.py
rename to tests/python/unittest/test_tir_transform_vectorize.py
index 2ade843..d7124b6 100644
--- a/tests/python/unittest/test_tir_pass_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -28,12 +28,16 @@ def test_vectorize_loop():
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.index, tvm.tir.Ramp)
assert isinstance(stmt.body.value, tvm.tir.Broadcast)
+
def test_vectorize_vector():
dtype = 'int64'
n = te.var('n')
@@ -44,7 +48,10 @@ def test_vectorize_vector():
A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.index, tvm.tir.Ramp)
@@ -63,13 +70,17 @@ def test_vectorize_with_if():
with ib.if_scope(i < n):
A[i] = 2.0
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.IfThenElse)
assert isinstance(stmt.then_case.index, tvm.tir.Ramp)
assert isinstance(stmt.then_case.value, tvm.tir.Add)
assert stmt.then_case.value.dtype == "float32x4"
assert isinstance(stmt.else_case, tvm.tir.For)
+
def test_vectorize_with_le_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
@@ -78,9 +89,13 @@ def test_vectorize_with_le_cond():
with ib.if_scope(i <= n):
A[i] = A[i] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
+
def test_vectorize_with_ge_cond():
n = te.var('n')
ib = tvm.tir.ir_builder.create()
@@ -89,9 +104,13 @@ def test_vectorize_with_ge_cond():
with ib.if_scope(i >= n):
A[i] = A[i] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
+
def test_vectorize_if_then_else():
n = te.var('n')
x = te.var('x')
@@ -102,7 +121,10 @@ def test_vectorize_if_then_else():
i > 0,
A[i] + 1, A[i])
stmt = ib.get()
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert isinstance(stmt, tvm.tir.For)
@@ -114,8 +136,12 @@ def test_vectorize_if_then_else():
k > 0,
A[k * 4 + i], 0)
stmt = ib.get()
+
assert isinstance(stmt.body, tvm.tir.For)
- stmt = tvm.tir.ir_pass.VectorizeLoop(stmt)
+
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
+ stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+
assert not isinstance(stmt.body, tvm.tir.For)
assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)