This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 542274dde9 [Schedule] Add an optional argument `disable_checks` for
`Schedule` (#14281)
542274dde9 is described below
commit 542274dde9323ea8b3da2b291b15a6e347565e8c
Author: Zihao Ye <[email protected]>
AuthorDate: Sun Mar 19 07:54:35 2023 -0700
[Schedule] Add an optional argument `disable_checks` for `Schedule` (#14281)
# Motivation
Currently, some of the schedule checks are too strict, which makes it hard
to schedule some workloads such as FlashAttention, whose reduction is two-stage
and does not strictly follow our standard.
This PR adds an optional argument that mutes some checks. In the code it is
named `enable_check` (not `disable_checks` as in the title) and defaults to
`True`; set it to `False` whenever we want to disable some `soft` checks.
A `soft` check is one whose violation does not necessarily make the schedule
invalid, whereas violating a `hard` check means the schedule step is invalid.
In the future, we should collect the `soft` and `hard` checks for all
schedule primitives. This PR serves FlashAttention and only covers `bind`
and some reduction primitives for now.
---
include/tvm/tir/schedule/schedule.h | 13 ++++++++-----
include/tvm/tir/schedule/state.h | 9 ++++++++-
python/tvm/tir/schedule/schedule.py | 21 +++++++++++++++++++++
python/tvm/tir/schedule/state.py | 14 ++++++++++++++
src/tir/schedule/analysis/analysis.cc | 2 +-
src/tir/schedule/concrete_schedule.cc | 6 ++++--
src/tir/schedule/primitive/for_kind.cc | 4 +++-
src/tir/schedule/primitive/reduction.cc | 30 ++++++++++++++++++------------
src/tir/schedule/schedule.cc | 10 ++++++----
src/tir/schedule/state.cc | 23 ++++++++++++++++-------
src/tir/schedule/traced_schedule.cc | 5 +++--
11 files changed, 102 insertions(+), 35 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 22febfdfed..ef5fd637d3 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -779,14 +779,15 @@ class Schedule : public runtime::ObjectRef {
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
+ * \param enable_check Whether to enable some prequisite checks for schedule
primitives, it's
+ * user's duty to guarantee the schedule correctness if we disable the
checks.
* \return The concrete schedule created
* \sa ScheduleDebugMask
- * \note The checks performed includes:
- * 1) VerifySRefTree
- * 2) VerifyCachedFlags
+ * \note The checks performed includes: 1) VerifySRefTree 2)
VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, ScheduleErrorRenderLevel
error_render_level);
+ int debug_mask, ScheduleErrorRenderLevel
error_render_level,
+ bool enable_check = true);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
@@ -794,6 +795,7 @@ class Schedule : public runtime::ObjectRef {
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
+ * \param enable_check Whether to enable prequisite checks for schedule
primitives.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed include:
@@ -801,7 +803,8 @@ class Schedule : public runtime::ObjectRef {
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, ScheduleErrorRenderLevel
error_render_level);
+ int debug_mask, ScheduleErrorRenderLevel
error_render_level,
+ bool enable_check = true);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef,
ScheduleNode);
};
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 201d78fe63..a089de2799 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -81,6 +81,7 @@ enum ScheduleDebugMask : uint32_t {
* 3) The dependency information of each block scope (block_info)
* 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
* 5) A debug flag, if set, extra checking is enabled (debug_mask)
+ * 6) A check flag, if set, enable prequisite check for schedule primitives
(enable_check)
*/
class ScheduleStateNode : public Object {
public:
@@ -100,12 +101,17 @@ class ScheduleStateNode : public Object {
* \sa ScheduleDebugMask
*/
int debug_mask;
+ /*!
+ * \brief Whether to enable prequisite checks for schedule primitives.
+ */
+ bool enable_check;
void VisitAttrs(AttrVisitor* v) {
v->Visit("mod", &mod);
// `block_info` is not visited
// `stmt2ref` is not visited
v->Visit("debug_mask", &debug_mask);
+ v->Visit("enable_check", &enable_check);
}
/*!
* \brief Replace the part of the AST, as being pointed to by `src_sref`,
@@ -194,8 +200,9 @@ class ScheduleState : public ObjectRef {
* \param mod The IRModule to be scheduled
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
+ * \param enable_check Whether enables prerequisite checks for schedule
primitives.
*/
- TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);
+ TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool
enable_check = true);
/*! \return The mutable pointer to the ScheduleStateNode */
ScheduleStateNode* get() const { return
static_cast<ScheduleStateNode*>(data_.get()); }
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index d86cd86ea0..73e9021b6c 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -81,6 +81,14 @@ def _parse_error_render_level(error_render_level: str) ->
int:
return _ERROR_RENDER_LEVEL.get(error_render_level)
+def _parse_enable_checks(enable_checks: bool) -> bool:
+ if not isinstance(enable_checks, bool):
+ raise TypeError(
+ "enable_checks only accepts bool value, got {}
instead".format(type(enable_checks))
+ )
+ return enable_checks
+
+
def _parse_seed(seed: Optional[int]) -> int:
if seed is None:
return -1
@@ -114,6 +122,7 @@ class Schedule(Object):
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
+ enable_check: bool = True,
) -> None:
"""Construct a TensorIR schedule class from an IRModule
@@ -137,6 +146,15 @@ class Schedule(Object):
- "detail": Render a detailed error message, with the TIR and
error locations printed
- "fast: Show a simple error message without rendering or string
manipulation
- "none": Do not show any error message.
+ enable_check : bool = True
+ The default schedule checks are too strict and might prevent us
performing some valid
+ schedules. `enable_check` is an argument to control whether we
enable prerequisite
+ checks for some schedule primitives or not:
+ - true: perform prerequisite check before applying some schedules.
+ - false: do not perform some check before applying schedules, but
still raise error
+ if schedule fails.
+
+ It's user duty to guarantee schedule correctness if `enable_check`
is set to `False`.
Note
----
@@ -151,6 +169,7 @@ class Schedule(Object):
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
+ _parse_enable_checks(enable_check),
)
@staticmethod
@@ -160,6 +179,7 @@ class Schedule(Object):
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
+ enable_check: bool = True,
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint:
disable=no-member
@@ -167,6 +187,7 @@ class Schedule(Object):
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
+ _parse_enable_checks(enable_check),
)
########## Utilities ##########
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..dab84b2fcc 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -70,6 +70,14 @@ def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
return debug_mask
+def _parse_enable_checks(enable_checks: bool) -> bool:
+ if not isinstance(enable_checks, bool):
+ raise TypeError(
+ "enable_checks only accepts bool value, got {}
instead".format(type(enable_checks))
+ )
+ return enable_checks
+
+
@register_object("tir.ScheduleState")
class ScheduleState(Object):
"""The state of scheduling, which exposes a `Replace` method as
@@ -81,6 +89,7 @@ class ScheduleState(Object):
3) The dependency information of each block scope (block_info)
4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
5) A debug flag, if set, extra checking is enabled (debug_mask)
+ 6) A enable check flag, if False, some prerequisite checks are disabled.
Parameters
----------
@@ -89,6 +98,9 @@ class ScheduleState(Object):
debug_mask : int
Do extra correctness checking after the object construction
and each time after calling the Replace method.
+ enable_check : bool
+ Indicates whether we enable prerequisite checks for some schedule
primitives or not,
+ defaults to `True`.
"""
mod: IRModule
@@ -99,6 +111,7 @@ class ScheduleState(Object):
mod: Union[PrimFunc, IRModule],
*,
debug_mask: Union[str, int] = "none",
+ enable_check: bool = True,
) -> None:
"""Construct a schedule state from an IRModule or a PrimFunc
@@ -118,6 +131,7 @@ class ScheduleState(Object):
_ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_debug_mask(debug_mask),
+ _parse_enable_checks(enable_check),
)
def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 744801596e..b35d64f125 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -103,7 +103,7 @@ Definition of a scope that is a stage pipeline:
}
}
// Step 2. Handle `require_stage_pipeline`
- if (require_stage_pipeline) {
+ if (require_stage_pipeline && self->enable_check) {
bool stage_pipeline =
self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
if (stage_pipeline == false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 5a9dab4854..75a8fc0a14 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -24,9 +24,10 @@ namespace tvm {
namespace tir {
Schedule Schedule::Concrete(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, ScheduleErrorRenderLevel
error_render_level) {
+ int debug_mask, ScheduleErrorRenderLevel
error_render_level,
+ bool enable_check) {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
- n->state_ = ScheduleState(mod, debug_mask);
+ n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
@@ -60,6 +61,7 @@ class ScheduleCopier {
n->block_info = copier.Copy(src_state->block_info);
n->stmt2ref = copier.Copy(src_state->stmt2ref);
n->debug_mask = src_state->debug_mask;
+ n->enable_check = src_state->enable_check;
*new_state = ScheduleState(std::move(n));
*new_symbol_table = copier.Copy(self->symbol_table_);
}
diff --git a/src/tir/schedule/primitive/for_kind.cc
b/src/tir/schedule/primitive/for_kind.cc
index cc8cb55fd3..02d8866e8e 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -157,7 +157,9 @@ void ParallelizeComputation(const ScheduleState& self,
const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has
compact data flow.
- CheckSubtreeCompactDataflow(self, loop_sref);
+ if (self->enable_check) {
+ CheckSubtreeCompactDataflow(self, loop_sref);
+ }
// Step 2. Check whether the loop can be parallelized/vectorized/bound with
regard to each
// underlying block.
diff --git a/src/tir/schedule/primitive/reduction.cc
b/src/tir/schedule/primitive/reduction.cc
index bb43df1ce9..d39252f3ce 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -188,17 +188,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const
StmtSRef& block_sref,
// Get the outer loops from high to low
Array<StmtSRef> loops = GetLoops(block_sref);
const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
- // Cond 0. Check loop_sref is an ancestor of block_sref
- if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
- throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
- "decompose_reduction");
- }
- // Cond 1. Check block is reduction
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
- CheckReductionBlock(self, block_sref, scope_root_sref);
- // Cond 2. Check 'loop' is higher than all the loops related to block var of
type reduction
- LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize,
loops, loop_sref);
+ if (self->enable_check) {
+ // Cond 0. Check loop_sref is an ancestor of block_sref
+ if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
+ throw LoopPositionError(self->mod, GetRef<For>(loop),
GetRef<Block>(block),
+ "decompose_reduction");
+ }
+ // Cond 1. Check block is reduction
+ CheckReductionBlock(self, block_sref, scope_root_sref);
+ // Cond 2. Check 'loop' is higher than all the loops related to block var
of type reduction
+ LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize,
loops, loop_sref);
+ }
// IR Manipulation
ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
@@ -1176,7 +1178,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef&
rf_loop_sref, int factor_ax
const Block& block = block_realize->block;
StmtSRef scope_root = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/true);
- CheckReductionBlock(self, block_sref, scope_root);
+ if (self->enable_check) {
+ CheckReductionBlock(self, block_sref, scope_root);
+ }
const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
if (rf_loop->kind != ForKind::kSerial) {
throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
@@ -1199,8 +1203,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef&
rf_loop_sref, int factor_ax
// - the outermost loop should have the reduction block as its first child
block;
// - the outermost loop that is touched by some reduction block iters can
only have one child
// block.
- LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block,
data_par_loop_vars,
- reduce_loop_vars);
+ if (self->enable_check) {
+ LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block,
data_par_loop_vars,
+ reduce_loop_vars);
+ }
// Step 5. Get the `init` identity and the `update` combiner of the
reduction. Extract the
// commutative reducer, combiner lhs and combiner rhs from the reduction
identity and the
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index cb8b5a1d77..dcaa61e1bb 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -65,15 +65,17 @@
TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV
TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return
LoopRV(); });
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
.set_body_typed([](IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, int error_render_level) -> Schedule {
+ int debug_mask, int error_render_level, bool
enable_check) -> Schedule {
return Schedule::Concrete(mod, debug_mask, seed,
-
static_cast<ScheduleErrorRenderLevel>(error_render_level));
+
static_cast<ScheduleErrorRenderLevel>(error_render_level),
+ enable_check);
});
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
.set_body_typed([](IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, int error_render_level) -> Schedule {
+ int debug_mask, int error_render_level, bool
enable_check) -> Schedule {
return Schedule::Traced(mod, seed, debug_mask,
-
static_cast<ScheduleErrorRenderLevel>(error_render_level));
+
static_cast<ScheduleErrorRenderLevel>(error_render_level),
+ enable_check);
});
/******** (FFI) Lookup random variables ********/
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index a901eff6f2..a7a1c0d482 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -402,16 +402,21 @@ class BlockInfoCollector : private StmtVisitor {
class StateCreator : private StmtVisitor {
public:
/*!
- * \brief The entry function
- * \param self The schedule state to be completed
+ * \brief ScheduleState Creator
+ * \param mod The module being scheduled.
+ * \param debug_mask Do extra correctness checking after the class creation
+ * and each time after calling the Replace method.
+ * \param enable_check Whether to enable prequisite checks for schedule
primitives.
*/
- static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
+ static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask,
bool enable_check) {
ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
ScheduleStateNode* self = n.get();
// Set `n->mod`
n->mod = std::move(mod);
// Set `n->debug_mask`
n->debug_mask = debug_mask;
+ // Set `n->enable_check`
+ n->enable_check = enable_check;
// Set `n->stmt2ref` and `n->block_info`
StateCreator creator(self);
for (const auto& kv : n->mod->functions) {
@@ -426,6 +431,10 @@ class StateCreator : private StmtVisitor {
}
private:
+ /*!
+ * \brief The entry function
+ * \param self The schedule state to be completed
+ */
explicit StateCreator(ScheduleStateNode* self) : self_(self) {}
/*!
@@ -481,9 +490,9 @@ class StateCreator : private StmtVisitor {
/**************** Constructor ****************/
-ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
+ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1
is not supported";
- data_ = StateCreator::Create(mod, debug_mask);
+ data_ = StateCreator::Create(mod, debug_mask, enable_check);
}
/**************** Replace ****************/
@@ -1108,8 +1117,8 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState&
self, const StmtSRef& bl
TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
- .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState {
- return ScheduleState(mod, debug_mask);
+ .set_body_typed([](IRModule mod, int debug_mask, bool enable_check) ->
ScheduleState {
+ return ScheduleState(mod, debug_mask, enable_check);
});
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
.set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index a5cb66a0cb..1ccc82f302 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -22,9 +22,10 @@ namespace tvm {
namespace tir {
Schedule Schedule::Traced(IRModule mod,
support::LinearCongruentialEngine::TRandState seed,
- int debug_mask, ScheduleErrorRenderLevel
error_render_level) {
+ int debug_mask, ScheduleErrorRenderLevel
error_render_level,
+ bool enable_check) {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
- n->state_ = ScheduleState(mod, debug_mask);
+ n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();