This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 542274dde9 [Schedule] Add an optional argument `disable_checks` for 
`Schedule` (#14281)
542274dde9 is described below

commit 542274dde9323ea8b3da2b291b15a6e347565e8c
Author: Zihao Ye <[email protected]>
AuthorDate: Sun Mar 19 07:54:35 2023 -0700

    [Schedule] Add an optional argument `disable_checks` for `Schedule` (#14281)
    
    # Motivation
    Currently, some of the schedule checks are too strict, which makes it hard
    to schedule some workloads such as FlashAttention, whose reduction is
    two-stage and does not strictly follow our standard reduction pattern.
    
    This PR adds an optional argument that mutes some checks. It is referred to
    as `disable_checks` in the title, but the code names it `enable_check` with
    the opposite polarity: it defaults to `True`, and we can set it to `False`
    whenever we want to disable some `soft` checks. A `soft` check is one whose
    violation does not necessarily make the schedule invalid. Violating a `hard`
    check, by contrast, means the schedule step is invalid.
    
    In the future, we should classify the checks of every schedule primitive as
    either `soft` or `hard`. For now, this PR serves FlashAttention and only
    covers `bind` and some reduction primitives (`decompose_reduction` and
    `rfactor`).
---
 include/tvm/tir/schedule/schedule.h     | 13 ++++++++-----
 include/tvm/tir/schedule/state.h        |  9 ++++++++-
 python/tvm/tir/schedule/schedule.py     | 21 +++++++++++++++++++++
 python/tvm/tir/schedule/state.py        | 14 ++++++++++++++
 src/tir/schedule/analysis/analysis.cc   |  2 +-
 src/tir/schedule/concrete_schedule.cc   |  6 ++++--
 src/tir/schedule/primitive/for_kind.cc  |  4 +++-
 src/tir/schedule/primitive/reduction.cc | 30 ++++++++++++++++++------------
 src/tir/schedule/schedule.cc            | 10 ++++++----
 src/tir/schedule/state.cc               | 23 ++++++++++++++++-------
 src/tir/schedule/traced_schedule.cc     |  5 +++--
 11 files changed, 102 insertions(+), 35 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 22febfdfed..ef5fd637d3 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -779,14 +779,15 @@ class Schedule : public runtime::ObjectRef {
    * \param debug_mask Do extra correctness checking after the class creation
    * and each time after calling the Replace method.
    * \param error_render_level The level of error rendering
+   * \param enable_check Whether to enable some prequisite checks for schedule 
primitives, it's
+   *   user's duty to guarantee the schedule correctness if we disable the 
checks.
    * \return The concrete schedule created
    * \sa ScheduleDebugMask
-   * \note The checks performed includes:
-   * 1) VerifySRefTree
-   * 2) VerifyCachedFlags
+   * \note The checks performed includes: 1) VerifySRefTree 2) 
VerifyCachedFlags
    */
   TVM_DLL static Schedule Concrete(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                                   int debug_mask, ScheduleErrorRenderLevel 
error_render_level);
+                                   int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
+                                   bool enable_check = true);
   /*!
    * \brief Construct a traced concrete TensorIR schedule from an IRModule
    * \param mod The IRModule to be scheduled
@@ -794,6 +795,7 @@ class Schedule : public runtime::ObjectRef {
    * \param debug_mask Do extra correctness checking after the class creation
    * and each time after calling the Replace method.
    * \param error_render_level The level of error rendering
+   * \param enable_check Whether to enable prequisite checks for schedule 
primitives.
    * \return The concrete schedule created
    * \sa ScheduleDebugMask
    * \note The checks performed include:
@@ -801,7 +803,8 @@ class Schedule : public runtime::ObjectRef {
    * 2) VerifyCachedFlags
    */
   TVM_DLL static Schedule Traced(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                                 int debug_mask, ScheduleErrorRenderLevel 
error_render_level);
+                                 int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
+                                 bool enable_check = true);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, 
ScheduleNode);
 };
 
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 201d78fe63..a089de2799 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -81,6 +81,7 @@ enum ScheduleDebugMask : uint32_t {
  * 3) The dependency information of each block scope (block_info)
  * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
  * 5) A debug flag, if set, extra checking is enabled (debug_mask)
+ * 6) A check flag, if set, enable prequisite check for schedule primitives 
(enable_check)
  */
 class ScheduleStateNode : public Object {
  public:
@@ -100,12 +101,17 @@ class ScheduleStateNode : public Object {
    * \sa ScheduleDebugMask
    */
   int debug_mask;
+  /*!
+   * \brief Whether to enable prequisite checks for schedule primitives.
+   */
+  bool enable_check;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("mod", &mod);
     // `block_info` is not visited
     // `stmt2ref` is not visited
     v->Visit("debug_mask", &debug_mask);
+    v->Visit("enable_check", &enable_check);
   }
   /*!
    * \brief Replace the part of the AST, as being pointed to by `src_sref`,
@@ -194,8 +200,9 @@ class ScheduleState : public ObjectRef {
    * \param mod The IRModule to be scheduled
    * \param debug_mask Do extra correctness checking after the class creation
    * and each time after calling the Replace method.
+   * \param enable_check Whether enables prerequisite checks for schedule 
primitives.
    */
-  TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);
+  TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool 
enable_check = true);
 
   /*! \return The mutable pointer to the ScheduleStateNode */
   ScheduleStateNode* get() const { return 
static_cast<ScheduleStateNode*>(data_.get()); }
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index d86cd86ea0..73e9021b6c 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -81,6 +81,14 @@ def _parse_error_render_level(error_render_level: str) -> 
int:
     return _ERROR_RENDER_LEVEL.get(error_render_level)
 
 
+def _parse_enable_checks(enable_checks: bool) -> bool:
+    if not isinstance(enable_checks, bool):
+        raise TypeError(
+            "enable_checks only accepts bool value, got {} 
instead".format(type(enable_checks))
+        )
+    return enable_checks
+
+
 def _parse_seed(seed: Optional[int]) -> int:
     if seed is None:
         return -1
@@ -114,6 +122,7 @@ class Schedule(Object):
         seed: Optional[int] = None,
         debug_mask: Union[str, int] = "none",
         error_render_level: str = "detail",
+        enable_check: bool = True,
     ) -> None:
         """Construct a TensorIR schedule class from an IRModule
 
@@ -137,6 +146,15 @@ class Schedule(Object):
             - "detail": Render a detailed error message, with the TIR and 
error locations printed
             - "fast: Show a simple error message without rendering or string 
manipulation
             - "none": Do not show any error message.
+        enable_check : bool = True
+            The default schedule checks are too strict and might prevent us 
performing some valid
+            schedules. `enable_check` is an argument to control whether we 
enable prerequisite
+            checks for some schedule primitives or not:
+            - true: perform prerequisite check before applying some schedules.
+            - false: do not perform some check before applying schedules, but 
still raise error
+            if schedule fails.
+
+            It's user duty to guarantee schedule correctness if `enable_check` 
is set to `False`.
 
         Note
         ----
@@ -151,6 +169,7 @@ class Schedule(Object):
             _parse_seed(seed),
             _parse_debug_mask(debug_mask),
             _parse_error_render_level(error_render_level),
+            _parse_enable_checks(enable_check),
         )
 
     @staticmethod
@@ -160,6 +179,7 @@ class Schedule(Object):
         seed: Optional[int] = None,
         debug_mask: Union[str, int] = "none",
         error_render_level: str = "detail",
+        enable_check: bool = True,
     ) -> "Schedule":
         """Construct a non-traced TensorIR schedule class from an IRModule."""
         return _ffi_api.ConcreteSchedule(  # type: ignore # pylint: 
disable=no-member
@@ -167,6 +187,7 @@ class Schedule(Object):
             _parse_seed(seed),
             _parse_debug_mask(debug_mask),
             _parse_error_render_level(error_render_level),
+            _parse_enable_checks(enable_check),
         )
 
     ########## Utilities ##########
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..dab84b2fcc 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -70,6 +70,14 @@ def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
     return debug_mask
 
 
+def _parse_enable_checks(enable_checks: bool) -> bool:
+    if not isinstance(enable_checks, bool):
+        raise TypeError(
+            "enable_checks only accepts bool value, got {} 
instead".format(type(enable_checks))
+        )
+    return enable_checks
+
+
 @register_object("tir.ScheduleState")
 class ScheduleState(Object):
     """The state of scheduling, which exposes a `Replace` method as
@@ -81,6 +89,7 @@ class ScheduleState(Object):
     3) The dependency information of each block scope (block_info)
     4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
     5) A debug flag, if set, extra checking is enabled (debug_mask)
+    6) A enable check flag, if False, some prerequisite checks are disabled.
 
     Parameters
     ----------
@@ -89,6 +98,9 @@ class ScheduleState(Object):
     debug_mask : int
         Do extra correctness checking after the object construction
         and each time after calling the Replace method.
+    enable_check : bool
+        Indicates whether we enable prerequisite checks for some schedule 
primitives or not,
+        defaults to `True`.
     """
 
     mod: IRModule
@@ -99,6 +111,7 @@ class ScheduleState(Object):
         mod: Union[PrimFunc, IRModule],
         *,
         debug_mask: Union[str, int] = "none",
+        enable_check: bool = True,
     ) -> None:
         """Construct a schedule state from an IRModule or a PrimFunc
 
@@ -118,6 +131,7 @@ class ScheduleState(Object):
             _ffi_api.ScheduleState,  # type: ignore # pylint: disable=no-member
             _parse_mod(mod),
             _parse_debug_mask(debug_mask),
+            _parse_enable_checks(enable_check),
         )
 
     def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 744801596e..b35d64f125 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -103,7 +103,7 @@ Definition of a scope that is a stage pipeline:
     }
   }
   // Step 2. Handle `require_stage_pipeline`
-  if (require_stage_pipeline) {
+  if (require_stage_pipeline && self->enable_check) {
     bool stage_pipeline = 
self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
     if (stage_pipeline == false) {
       const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index 5a9dab4854..75a8fc0a14 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -24,9 +24,10 @@ namespace tvm {
 namespace tir {
 
 Schedule Schedule::Concrete(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                            int debug_mask, ScheduleErrorRenderLevel 
error_render_level) {
+                            int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
+                            bool enable_check) {
   ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
-  n->state_ = ScheduleState(mod, debug_mask);
+  n->state_ = ScheduleState(mod, debug_mask, enable_check);
   n->error_render_level_ = error_render_level;
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();
@@ -60,6 +61,7 @@ class ScheduleCopier {
     n->block_info = copier.Copy(src_state->block_info);
     n->stmt2ref = copier.Copy(src_state->stmt2ref);
     n->debug_mask = src_state->debug_mask;
+    n->enable_check = src_state->enable_check;
     *new_state = ScheduleState(std::move(n));
     *new_symbol_table = copier.Copy(self->symbol_table_);
   }
diff --git a/src/tir/schedule/primitive/for_kind.cc 
b/src/tir/schedule/primitive/for_kind.cc
index cc8cb55fd3..02d8866e8e 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -157,7 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, 
const StmtSRef& loop_sref
    * parallelized/vectorized/bound.
    */
   // Step 1. Check whether the subtree rooted from the `loop` in sref tree has 
compact data flow.
-  CheckSubtreeCompactDataflow(self, loop_sref);
+  if (self->enable_check) {
+    CheckSubtreeCompactDataflow(self, loop_sref);
+  }
 
   // Step 2. Check whether the loop can be parallelized/vectorized/bound with 
regard to each
   // underlying block.
diff --git a/src/tir/schedule/primitive/reduction.cc 
b/src/tir/schedule/primitive/reduction.cc
index bb43df1ce9..d39252f3ce 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -188,17 +188,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const 
StmtSRef& block_sref,
   // Get the outer loops from high to low
   Array<StmtSRef> loops = GetLoops(block_sref);
   const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
-  // Cond 0. Check loop_sref is an ancestor of block_sref
-  if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
-    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
-                            "decompose_reduction");
-  }
-  // Cond 1. Check block is reduction
   StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
                                           /*require_stage_pipeline=*/false);
-  CheckReductionBlock(self, block_sref, scope_root_sref);
-  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
-  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  if (self->enable_check) {
+    // Cond 0. Check loop_sref is an ancestor of block_sref
+    if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
+      throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block),
+                              "decompose_reduction");
+    }
+    // Cond 1. Check block is reduction
+    CheckReductionBlock(self, block_sref, scope_root_sref);
+    // Cond 2. Check 'loop' is higher than all the loops related to block var 
of type reduction
+    LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  }
   // IR Manipulation
   ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
   ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
@@ -1176,7 +1178,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& 
rf_loop_sref, int factor_ax
   const Block& block = block_realize->block;
   StmtSRef scope_root = GetScopeRoot(self, block_sref,  //
                                      /*require_stage_pipeline=*/true);
-  CheckReductionBlock(self, block_sref, scope_root);
+  if (self->enable_check) {
+    CheckReductionBlock(self, block_sref, scope_root);
+  }
   const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
   if (rf_loop->kind != ForKind::kSerial) {
     throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
@@ -1199,8 +1203,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& 
rf_loop_sref, int factor_ax
   // - the outermost loop should have the reduction block as its first child 
block;
   // - the outermost loop that is touched by some reduction block iters can 
only have one child
   // block.
-  LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, 
data_par_loop_vars,
-                                       reduce_loop_vars);
+  if (self->enable_check) {
+    LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, 
data_par_loop_vars,
+                                         reduce_loop_vars);
+  }
 
   // Step 5. Get the `init` identity and the `update` combiner of the 
reduction. Extract the
   // commutative reducer, combiner lhs and combiner rhs from the reduction 
identity and the
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index cb8b5a1d77..dcaa61e1bb 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -65,15 +65,17 @@ 
TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV
 TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return 
LoopRV(); });
 TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
     .set_body_typed([](IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                       int debug_mask, int error_render_level) -> Schedule {
+                       int debug_mask, int error_render_level, bool 
enable_check) -> Schedule {
       return Schedule::Concrete(mod, debug_mask, seed,
-                                
static_cast<ScheduleErrorRenderLevel>(error_render_level));
+                                
static_cast<ScheduleErrorRenderLevel>(error_render_level),
+                                enable_check);
     });
 TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
     .set_body_typed([](IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                       int debug_mask, int error_render_level) -> Schedule {
+                       int debug_mask, int error_render_level, bool 
enable_check) -> Schedule {
       return Schedule::Traced(mod, seed, debug_mask,
-                              
static_cast<ScheduleErrorRenderLevel>(error_render_level));
+                              
static_cast<ScheduleErrorRenderLevel>(error_render_level),
+                              enable_check);
     });
 
 /******** (FFI) Lookup random variables ********/
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index a901eff6f2..a7a1c0d482 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -402,16 +402,21 @@ class BlockInfoCollector : private StmtVisitor {
 class StateCreator : private StmtVisitor {
  public:
   /*!
-   * \brief The entry function
-   * \param self The schedule state to be completed
+   * \brief ScheduleState Creator
+   * \param mod The module being scheduled.
+   * \param debug_mask Do extra correctness checking after the class creation
+   * and each time after calling the Replace method.
+   * \param enable_check Whether to enable prequisite checks for schedule 
primitives.
    */
-  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask, 
bool enable_check) {
     ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
     ScheduleStateNode* self = n.get();
     // Set `n->mod`
     n->mod = std::move(mod);
     // Set `n->debug_mask`
     n->debug_mask = debug_mask;
+    // Set `n->enable_check`
+    n->enable_check = enable_check;
     // Set `n->stmt2ref` and `n->block_info`
     StateCreator creator(self);
     for (const auto& kv : n->mod->functions) {
@@ -426,6 +431,10 @@ class StateCreator : private StmtVisitor {
   }
 
  private:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
   explicit StateCreator(ScheduleStateNode* self) : self_(self) {}
 
   /*!
@@ -481,9 +490,9 @@ class StateCreator : private StmtVisitor {
 
 /**************** Constructor ****************/
 
-ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
+ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
   CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 
is not supported";
-  data_ = StateCreator::Create(mod, debug_mask);
+  data_ = StateCreator::Create(mod, debug_mask, enable_check);
 }
 
 /**************** Replace ****************/
@@ -1108,8 +1117,8 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& 
self, const StmtSRef& bl
 
 TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
-    .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState {
-      return ScheduleState(mod, debug_mask);
+    .set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> 
ScheduleState {
+      return ScheduleState(mod, debug_mask, enable_check);
     });
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
     .set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index a5cb66a0cb..1ccc82f302 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -22,9 +22,10 @@ namespace tvm {
 namespace tir {
 
 Schedule Schedule::Traced(IRModule mod, 
support::LinearCongruentialEngine::TRandState seed,
-                          int debug_mask, ScheduleErrorRenderLevel 
error_render_level) {
+                          int debug_mask, ScheduleErrorRenderLevel 
error_render_level,
+                          bool enable_check) {
   ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
-  n->state_ = ScheduleState(mod, debug_mask);
+  n->state_ = ScheduleState(mod, debug_mask, enable_check);
   n->error_render_level_ = error_render_level;
   n->symbol_table_ = {};
   n->analyzer_ = std::make_unique<arith::Analyzer>();

Reply via email to