This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d4a396825b [TIR] Add schedule primitive TransformBlockLayout (#11485)
d4a396825b is described below

commit d4a396825bead4c617a4867a10dd6eff7797add4
Author: Wuwei Lin <[email protected]>
AuthorDate: Sun May 29 09:12:17 2022 -0700

    [TIR] Add schedule primitive TransformBlockLayout (#11485)
    
    * [TIR] Add schedule primitive TransformBlockLayout
    
    * fixup! [TIR] Add schedule primitive TransformBlockLayout
    
    Fix doc
---
 include/tvm/tir/schedule/schedule.h                |  10 +
 python/tvm/tir/schedule/schedule.py                |  61 +++++
 src/tir/schedule/analysis.h                        |  11 +
 src/tir/schedule/analysis/analysis.cc              |  29 ++
 src/tir/schedule/concrete_schedule.cc              |   8 +
 src/tir/schedule/concrete_schedule.h               |   1 +
 src/tir/schedule/primitive.h                       |  12 +
 .../schedule/primitive/layout_transformation.cc    | 304 +++++++++++++++++++++
 src/tir/schedule/primitive/loop_transformation.cc  |  29 +-
 src/tir/schedule/schedule.cc                       |   2 +
 src/tir/schedule/traced_schedule.cc                |  10 +
 src/tir/schedule/traced_schedule.h                 |   1 +
 src/tir/schedule/transform.cc                      |  31 +++
 src/tir/schedule/transform.h                       |  39 +++
 .../unittest/test_tir_schedule_transform_layout.py | 113 ++++++++
 15 files changed, 635 insertions(+), 26 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 18e15d1670..48014280a5 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -545,6 +545,16 @@ class ScheduleNode : public runtime::Object {
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
                                BufferIndexType buffer_index_type, const 
IndexMap& index_map) = 0;
 
+  /*!
+   * \brief Apply a transformation represented by an IndexMap to the iters of a block
+   * \details The block iters and the block body are rewritten by the given index_map, and the
+   * outer loops are regenerated with one loop per new block iter. The index_map must be
+   * bijective affine because its inverse mapping is used to rewrite the block body.
+   * \param block_rv The block to be transformed
+   * \param index_map The transformation to apply.
+   */
+  virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& 
index_map) = 0;
+
   /*!
    * \brief Set the axis separator of a buffer, where the buffer is specified 
by a block and a read
    * or write index
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index dc687b1eae..f86228848b 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2286,6 +2286,67 @@ class Schedule(Object):
                 self, block, buffer_index, buffer_index_type_enum, 
axis_separators
             )
 
+    @type_checked
+    def transform_block_layout(
+        self,
+        block: BlockRV,
+        index_map: Union[IndexMap, Callable],
+    ) -> None:
+        """Apply a transformation represented by an IndexMap to the iters of a block
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be transformed; its outer loops are regenerated.
+
+        index_map : Union[IndexMap, Callable]
+            The transformation to apply; must be bijective affine. Callables go through from_func.
+
+        Examples
+        --------
+
+        Before transform_block_layout, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_transform_block_layout(
+                A: T.Buffer[(16, 16), "float32"],
+                B: T.Buffer[(16, 16), "float32"]
+            ) -> None:
+                for i, j in T.grid(16, 16):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do transform_block_layout:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_transform_block_layout)
+            sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 
16 + j,))
+            print(sch.mod["main"].script())
+
+        After applying transform_block_layout, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_transform_block_layout(
+                A: T.Buffer[(16, 16), "float32"],
+                B: T.Buffer[(16, 16), "float32"]
+            ) -> None:
+                for i in range(256):
+                    with T.block("B"):
+                        vi, = T.axis.remap("S", [i])
+                        B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
+        """
+        if callable(index_map):  # e.g. a lambda, as in the example above
+            index_map = IndexMap.from_func(index_map)
+        _ffi_api.ScheduleTransformBlockLayout(  # type: ignore # pylint: 
disable=no-member
+            self, block, index_map
+        )
+
     @type_checked
     def set_axis_separator(
         self,
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index c9c3d72ae0..0574cfefad 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -277,6 +277,17 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& 
block_realize,
                                 std::unordered_set<const VarNode*>* 
data_par_vars,
                                 std::unordered_set<const VarNode*>* 
reduce_vars);
 
+/******** Loop properties ********/
+/*!
+ * \brief Check that the loop starts with zero, i.e. that its min can be proven to be 0.
+ * \param self The schedule state
+ * \param loop_sref The StmtSRef that points to the loop to be checked
+ * \param analyzer The arithmetic analyzer used to prove that the loop min equals 0
+ * \throw ScheduleError If the loop doesn't start with zero.
+ */
+void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& 
loop_sref,
+                             arith::Analyzer* analyzer);
+
 /******** Block-loop relation ********/
 
 /*!
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 4777ee2657..c4719015da 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -686,6 +686,35 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& 
block_realize,
   return has_block_vars_of_other_types;
 }
 
+/******** Loop properties ********/
+
+void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
+                             arith::Analyzer* analyzer) {
+  class LoopNotStartWithZeroError : public ScheduleError {
+   public:
+    explicit LoopNotStartWithZeroError(IRModule mod, For loop)
+        : mod_(std::move(mod)), loop_(std::move(loop)) {}
+
+    String FastErrorString() const final {
+      return "ScheduleError: The primitive only supports loop starting with 0";
+    }
+
+    String DetailRenderTemplate() const final {
+      return "The loop {0} does not start with 0, which is not supported";
+    }
+
+    IRModule mod() const final { return mod_; }
+    Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+    IRModule mod_;
+    For loop_;
+  };
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  if (!analyzer->CanProve(loop->min == 0)) {  // reject any min not provably equal to 0
+    throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
+  }
+}
+
 /******** Block-loop relation ********/
 
 Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self,
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index 7b953220f2..8066d85a8e 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -693,6 +693,14 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& 
block_rv, int buffer_i
   TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
 }
 
+void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
+                                                const IndexMap& index_map) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map);  // see primitive.h
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_);
+}
+
 void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int 
buffer_index,
                                             BufferIndexType buffer_index_type,
                                             const Array<IntImm>& 
axis_separators) {
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index 9293aa3493..8e83aac2ce 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -134,6 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, 
BufferIndexType buffer_index_type,
                        const IndexMap& index_map) override;
+  void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& 
index_map) override;  // see ScheduleNode::TransformBlockLayout
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
                         const Array<IntImm>& axis_separators) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index d55b896934..50dedf71ff 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -442,6 +442,18 @@ TVM_DLL void Unannotate(ScheduleState self, const 
StmtSRef& sref, const String&
 TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, 
int buffer_index,
                              BufferIndexType buffer_index_type, const 
IndexMap& index_map);
 
+/*!
+ * \brief Apply a transformation represented by an IndexMap to the iters of a block
+ * \details The block iters and the block body are rewritten by index_map, and the outer loops
+ * are regenerated with one loop per new block iter. The index_map must be bijective affine
+ * because its inverse mapping is used to rewrite the block body.
+ * \param self The state of the schedule
+ * \param block_sref The block sref that refers to the block to be transformed
+ * \param index_map The transformation to apply.
+ */
+TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& 
block_sref,
+                                  const IndexMap& index_map);
+
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/primitive/layout_transformation.cc 
b/src/tir/schedule/primitive/layout_transformation.cc
index cf95665ee8..6da796fc95 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -192,6 +192,269 @@ void TransformLayout(ScheduleState self, const StmtSRef& 
block_sref, int buffer_
   self->Replace(scope_sref, new_scope_block, block_sref_reuse);
 }
 
+/*!
+ * \brief Detect the block iter type associated with the expression
+ *
+ * This function collects the block iters in the expression and checks whether they all have the
+ * same iter type. If they do, that iter type is returned; otherwise (or when the expression
+ * contains no block iter at all) the detected iter type is kOpaque.
+ *
+ * \param expr The expression
+ * \param block_iter_type_map The mapping from block iter to iter type
+ * \return The detected block iter type
+ */
+IterVarType DetectNewBlockIterType(
+    const PrimExpr& expr,
+    const std::unordered_map<const VarNode*, IterVarType>& block_iter_type_map) {
+  IterVarType result{kOpaque};
+  bool found = false;
+  // PostOrderVisit takes a void callback and cannot be stopped early, so once two different iter
+  // types have been seen, kOpaque must stay the result for the rest of the traversal.
+  PostOrderVisit(expr, [&](const ObjectRef& obj) {
+    if (const VarNode* var = obj.as<VarNode>()) {
+      auto it = block_iter_type_map.find(var);
+      if (it != block_iter_type_map.end()) {
+        if (!found) {
+          found = true;
+          result = it->second;
+        } else if (result != it->second) {
+          result = kOpaque;
+        }
+      }
+    }
+  });
+  return result;
+}
+
+class NotBijectiveAffineIndexMapError : public ScheduleError {  // raised when DetectIterMap fails
+ public:
+  NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map)
+      : mod_(std::move(mod)), index_map_(std::move(index_map)) {}
+  String FastErrorString() const final {
+    return "ScheduleError: The index map is not bijective affine.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The index map " << index_map_->ToPythonString() << " is not 
bijective affine.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }  // template has no {0}
+
+ private:
+  IRModule mod_;
+  IndexMap index_map_;
+};
+
+class IndexMapNotApplicableToBlockIterError : public ScheduleError {
+ public:
+  static void Check(const IRModule& mod, const Block& block, const IndexMap& index_map) {
+    if (index_map->initial_indices.size() != block->iter_vars.size()) {  // one param per iter
+      throw IndexMapNotApplicableToBlockIterError(mod, block, index_map);
+    }
+  }
+  explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map)
+      : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The index map can't be applied to block iters because the number of "
+           "parameters mismatch.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The index map " << index_map_->ToPythonString()
+       << " can't be applied to block iters of {0} because the number of parameters mismatch. "
+          "Expected: "
+       << index_map_->initial_indices.size() << ", actual: " << block_->iter_vars.size();
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }  // rendered as {0}
+
+ private:
+  IRModule mod_;
+  Block block_;
+  IndexMap index_map_;
+};
+
+class NotTrivialBindingError : public ScheduleError {
+ public:
+  explicit NotTrivialBindingError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+  static void CheckBlockHasTrivialBinding(
+      const IRModule& mod, const BlockRealize& block_realize,
+      const std::unordered_set<const VarNode*>& outer_loop_vars) {
+    // Every binding value must be the loop var of one of the outer loops.
+    for (const PrimExpr& iter_value : block_realize->iter_values) {
+      const VarNode* loop_var = iter_value.as<VarNode>();
+      if (!loop_var || !outer_loop_vars.count(loop_var)) {
+        throw NotTrivialBindingError(mod, block_realize->block);
+      }
+    }
+  }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The binding values of the block are not variables of outer loops.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The binding values of the {0} are not variables of outer loops.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+class OpaqueNewIterTypeError : public ScheduleError {  // raised when a new iter mixes iter types
+ public:
+  explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr 
iter_value)
+      : mod_(std::move(mod)), block_(std::move(block)), 
iter_value_(std::move(iter_value)) {}
+
+  String FastErrorString() const final {  // NOTE(review): also reachable with no block iter
+    return "ScheduleError: Cannot detect the new block iter type because it 
contains more than one "
+           "type of original iter vars.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "Cannot detect the block iter type for new iter value " << 
PrettyPrint(iter_value_)
+       << " in {0} because it contains more than one type of original iter 
vars.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  PrimExpr iter_value_;
+};
+
+void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
+                          const IndexMap& index_map) {
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  const Block& block = GetRef<Block>(block_ptr);
+  arith::Analyzer analyzer;
+
+  // Step 1: Collect outer loops and loop vars
+  Array<StmtSRef> loops = GetLoops(block_sref);  // outer loops of the block
+  std::unordered_set<const VarNode*> loop_vars;  // loop vars of the outer loops
+  for (const StmtSRef& loop_sref : loops) {
+    CheckLoopStartsWithZero(self, loop_sref, &analyzer);
+    loop_vars.emplace(loop_sref->StmtAs<ForNode>()->loop_var.get());
+  }
+
+  // Step 2: Check that all outer loops have a single child and that the block bindings are
+  // trivial (all binding values are loop vars)
+  StmtSRef scope_sref{nullptr};  // the scope statement for replacement
+  if (!loops.empty()) {
+    scope_sref = loops.front();
+    CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
+  } else {
+    scope_sref = block_sref;
+  }
+
+  BlockRealize block_realize = GetBlockRealize(self, block_sref);
+  NotTrivialBindingError::CheckBlockHasTrivialBinding(self->mod, block_realize, loop_vars);
+
+  // Step 3: Collect information of block iter vars
+  Array<PrimExpr> block_vars;      // iter_var->var of each block iter
+  Map<Var, Range> block_iter_dom;  // domain of block iter
+  std::unordered_map<const VarNode*, IterVarType> block_iter_type;  // iter type of block iter
+
+  Array<PrimExpr>
+      block_iter_range_array;  // array of block iter extents in the same order as block iters
+  for (const auto& iter_var : block->iter_vars) {
+    block_vars.push_back(iter_var->var);
+    block_iter_dom.Set(iter_var->var, iter_var->dom);
+    block_iter_type[iter_var->var.get()] = iter_var->iter_type;
+    ICHECK(is_zero(iter_var->dom->min));  // NOTE(review): likely user-reachable; ScheduleError?
+    block_iter_range_array.push_back(iter_var->dom->extent);
+  }
+
+  // Step 4: Apply the IndexMap to block iters.
+  IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map);
+  Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
+  Array<PrimExpr> new_block_iter_range = index_map->MapShape(block_iter_range_array);
+
+  auto iter_map = arith::DetectIterMap(
+      /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true),
+      /*require_bijective=*/true, &analyzer, /*simplify_trivial_iterators=*/true);
+  if (iter_map.empty()) {
+    throw NotBijectiveAffineIndexMapError(self->mod, index_map);
+  }
+
+  // Step 5: Create the new block after transformation.
+
+  // Step 5.1: Create new block iters. After applying the IndexMap f to block iters ax_0, ..., ax_n,
+  // create one block iter, with the same dtype, for each expression in f(ax_0, ..., ax_n).
+  Array<IterVar> new_block_iters;  // new block iters
+  Array<PrimExpr> new_block_vars;  // iter_var->var of new block iters
+  for (size_t i = 0; i < index_map->final_indices.size(); ++i) {
+    Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i].dtype()};
+    new_block_vars.push_back(new_block_var);
+    IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
+    if (iter_type == kOpaque) {
+      throw OpaqueNewIterTypeError(self->mod, block, transformed_block_iters[i]);
+    }
+    Range dom = Range::FromMinExtent(make_zero(new_block_var.dtype()), new_block_iter_range[i]);
+    new_block_iters.push_back(IterVar(dom, std::move(new_block_var), iter_type));
+  }
+
+  // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters
+  // in the body.
+
+  auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars);
+  // Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant
+  // zero.
+  for (const auto& iter_var : block->iter_vars) {
+    if (inverse_map.find(iter_var->var) == inverse_map.end()) {
+      ICHECK(is_one(iter_var->dom->extent));
+      inverse_map.Set(iter_var->var, make_zero(iter_var->var.dtype()));
+    }
+  }
+
+  Block new_block = Downcast<Block>(Substitute(block, inverse_map));
+  new_block.CopyOnWrite()->iter_vars = new_block_iters;
+  new_block = Downcast<Block>(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer));
+
+  // Step 5.3: Create outer loops for each new block iter.
+
+  // Make new loop vars, with the same dtypes as the new block vars
+  Array<PrimExpr> new_loop_vars;
+  for (int i = 0; i < static_cast<int>(new_block_iters.size()); ++i) {
+    new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_vars[i].dtype()));
+  }
+
+  // Make new block realize
+  BlockRealizeNode* new_block_realize = block_realize.CopyOnWrite();
+  new_block_realize->iter_values = new_loop_vars;
+  new_block_realize->block = new_block;
+
+  // Generate outer loops
+  Stmt body = GetRef<Stmt>(new_block_realize);
+  for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; --i) {
+    body = For(Downcast<Var>(new_loop_vars[i]), make_zero(new_loop_vars[i].dtype()),
+               new_block_iter_range[i], ForKind::kSerial, std::move(body));
+  }
+
+  // Step 6: Do the actual replacement
+  self->Replace(scope_sref, body, {{block, new_block}});
+}
+
 class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
  public:
   static Block Mutate(const Block& scope_block, const Buffer& old_buffer, 
Buffer new_buffer,
@@ -270,6 +533,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& 
block_sref, int buffer
   // Step 4: Replace the scope block with the new block
   self->Replace(scope_sref, new_scope_block, block_sref_reuse);
 }
+
 /******** InstructionKind Registration ********/
 
 struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits> {
@@ -324,6 +588,45 @@ struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits>
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct TransformBlockLayoutTraits : public 
UnpackedInstTraits<TransformBlockLayoutTraits> {
+  static constexpr const char* kName = "TransformBlockLayout";
+  static constexpr bool kIsPure = false;  // the instruction rewrites the IR
+
+ private:
+  static constexpr size_t kNumInputs = 1;     // block_rv
+  static constexpr size_t kNumAttrs = 1;      // index_map
+  static constexpr size_t kNumDecisions = 0;  // nothing is sampled
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap 
index_map) {
+    return sch->TransformBlockLayout(block_rv, index_map);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, 
IndexMap index_map) {
+    PythonAPICall py("transform_block_layout");
+    py.Input("block", block_rv);
+    py.Input("index_map", index_map->ToPythonString());  // index map rendered as Python text
+    return py.Str();
+  }
+
+ public:
+  static ObjectRef AttrsAsJSON(const Array<ObjectRef>& attrs) {
+    Array<ObjectRef> attrs_record;
+    attrs_record.reserve(kNumAttrs);
+    attrs_record.push_back(String(::tvm::SaveJSON(attrs[0])));  // IndexMap -> JSON string
+    return std::move(attrs_record);
+  }
+
+  static Array<ObjectRef> AttrsFromJSON(const ObjectRef& attrs_record_) {
+    Array<ObjectRef> attrs_record = Downcast<Array<ObjectRef>>(attrs_record_);
+    Array<ObjectRef> attrs;
+    attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[0])));  // inverse of AttrsAsJSON
+    return attrs;
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 struct SetAxisSeparatorTraits : public 
UnpackedInstTraits<SetAxisSeparatorTraits> {
   static constexpr const char* kName = "SetAxisSeparator";
   static constexpr bool kIsPure = false;
@@ -359,6 +662,7 @@ struct SetAxisSeparatorTraits : public 
UnpackedInstTraits<SetAxisSeparatorTraits
 };
 
 TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits);
+TVM_REGISTER_INST_KIND_TRAITS(TransformBlockLayoutTraits);  // instruction kind "TransformBlockLayout"
 TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits);
 
 }  // namespace tir
diff --git a/src/tir/schedule/primitive/loop_transformation.cc 
b/src/tir/schedule/primitive/loop_transformation.cc
index d64a72ed34..dbe6a3bbc0 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -250,25 +250,6 @@ class NotOnlyChildError : public ScheduleError {
   For inner_;
 };
 
-class LoopNotStartWithZeroError : public ScheduleError {
- public:
-  explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), 
loop_(std::move(loop)) {}
-
-  String FastErrorString() const final {
-    return "ScheduleError: The primitive only supports loop starting with 0";
-  }
-
-  String DetailRenderTemplate() const final {
-    return "The loop {0} does not start with 0, which is not supported";
-  }
-
-  IRModule mod() const final { return mod_; }
-  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
-
-  IRModule mod_;
-  For loop_;
-};
-
 class NotSingleInferFactorError : public ScheduleError {
  public:
   explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
@@ -407,10 +388,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& 
loop_sref,
   }
   // Currently, loops not starting with 0 are not supported
   arith::Analyzer analyzer;
-  if (!analyzer.CanProve(loop->min == 0)) {
-    throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
-  }
-  // Step 2. Replace all occurrences of the original loop var with new 
variables
+  CheckLoopStartsWithZero(self, loop_sref, &analyzer);
+
   int n = factors.size();
   PrimExpr substitute_value = 0;
   std::vector<Var> new_loop_vars;
@@ -482,9 +461,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& 
loop_srefs) {
     }
     outer_loop_sref = sref;
     outer_loop = loop;
-    if (!analyzer.CanProve(loop->min == 0)) {
-      throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
-    }
+    CheckLoopStartsWithZero(self, sref, &analyzer);
     const VarNode* used_var = nullptr;
     auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) {
       if (outer_loop_vars.count(var)) {
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 8dc0c52111..fb884ce77f 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -233,6 +233,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
       return self->TransformLayout(block_rv, buffer_index,
                                    
static_cast<BufferIndexType>(buffer_index_type), index_map);
     });
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")  // used by schedule.py
+    .set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
     .set_body_typed([](Schedule self, const BlockRV& block_rv, int 
buffer_index,
                        int buffer_index_type, const Array<IntImm>& 
axis_separators) {
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index 865b6f3784..8156480a45 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -442,6 +442,16 @@ void TracedScheduleNode::TransformLayout(const BlockRV& 
block_rv, int buffer_ind
                            /*outputs=*/{}));
 }
 
+void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const 
IndexMap& index_map) {
+  ConcreteScheduleNode::TransformBlockLayout(block_rv, index_map);  // apply before recording
+  static const InstructionKind& kind = 
InstructionKind::Get("TransformBlockLayout");
+  trace_->Append(
+      /*inst=*/Instruction(/*kind=*/kind,
+                           /*inputs=*/{block_rv},
+                           /*attrs=*/{index_map},
+                           /*outputs=*/{}));
+}
+
 void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int 
buffer_index,
                                           BufferIndexType buffer_index_type,
                                           const Array<IntImm>& 
axis_separators) {
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index 12c076d886..d1860be951 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -98,6 +98,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, 
BufferIndexType buffer_index_type,
                        const IndexMap& index_map) override;
+  void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& 
index_map) override;  // applies, then appends to trace_
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
                         const Array<IntImm>& axis_separators) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 6c4f3e1b7a..79802ecd65 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -280,5 +280,36 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& 
sch, const tir::Block
 
 
TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);
 
+/******** BlockBufferAccessSimplifier ********/
+void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_access_regions) {
+  auto fmutate = [this](const BufferRegion& buffer_region) {
+    Array<Range> new_buffer_region;  // simplified copy of buffer_region->region
+    for (const Range& range : buffer_region->region) {
+      new_buffer_region.push_back(Range::FromMinExtent(analyzer_->Simplify(range->min),
+                                                       analyzer_->Simplify(range->extent)));
+    }
+    return BufferRegion(buffer_region->buffer, new_buffer_region);
+  };
+  old_access_regions->MutateByApply(fmutate);  // rewrites the caller's regions in place
+}
+
+Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
+  Block block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));  // body first
+  auto* n = block.CopyOnWrite();
+  SimplifyAccessRegion(&n->reads);  // then the block's read/write regions
+  SimplifyAccessRegion(&n->writes);
+  return std::move(block);
+}
+
+Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) {
+  auto node = 
Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
+  return VisitBufferAccess(std::move(node));  // simplify the store indices
+}
+
+PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
+  auto node = 
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
+  return VisitBufferAccess(std::move(node));  // simplify the load indices
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 52e27350d4..192d44d9e9 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -26,6 +26,7 @@
 #include <unordered_map>
 #include <utility>
 
+#include "../../arith/ir_mutator_with_analyzer.h"
 #include "../ir/functor_common.h"
 
 namespace tvm {
@@ -172,6 +173,44 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const 
StmtSRef& leaf_block_
 Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const 
tir::BlockRV& block_rv,
                                            const String& intrin_name);
 
+/******** Block mutation ********/
+
+/*!
+ * \brief Simplifies BufferLoad/BufferStore indices and the read/write regions of blocks.
+ */
+class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer {
+ public:
+  /*!
+   * \brief Simplify buffer access indices and block read/write regions inside the statement
+   * \param stmt The statement to be simplified
+   * \param analyzer The arithmetic analyzer used for the simplification
+   * \return The simplified statement
+   */
+  static Stmt Simplify(const Stmt& stmt, arith::Analyzer* analyzer) {
+    BlockBufferAccessSimplifier simplifier(analyzer);
+    return simplifier(stmt);
+  }
+
+ private:
+  explicit BlockBufferAccessSimplifier(arith::Analyzer* analyzer)
+      : IRMutatorWithAnalyzer(analyzer) {}
+
+  using IRMutatorWithAnalyzer::VisitExpr_;
+  using IRMutatorWithAnalyzer::VisitStmt_;
+
+  void SimplifyAccessRegion(Array<BufferRegion>* old_access_regions);  // rewrites in place
+  Stmt VisitStmt_(const BlockNode* op) final;
+  Stmt VisitStmt_(const BufferStoreNode* op) final;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {  // used with BufferLoad and BufferStore (have `indices`)
+    node.CopyOnWrite()->indices.MutateByApply(
+        [this](const PrimExpr& expr) { return analyzer_->Simplify(expr); });
+    return node;
+  }
+};
+
 }  // namespace tir
 }  // namespace tvm
 
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py 
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 699eaf1236..e184bc3f62 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -91,6 +91,83 @@ def two_elementwise_transformed_output_buffer(
             C[vi // 16, vj // 16, vi % 16, vj % 16] = B[vi, vj] + 1.0
 
 
[email protected]_func
+def elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), 
"float32"]) -> None:
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+
+
[email protected]_func
+def elementwise_transformed(A: T.Buffer[(128, 128), "float32"], B: 
T.Buffer[(128, 128), "float32"]) -> None:
+    for i in range(16384):
+        with T.block("B"):
+            vi, = T.axis.remap("S", [i])
+            B[vi // 128, vi % 128] = A[vi // 128, vi % 128] * 2.0
+
+
[email protected]_func
+def conv2d_nhwc(
+    Input: T.Buffer[(1, 224, 224, 3), "float32"],
+    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
+    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
+) -> None:
+    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+        with T.block("PadInput"):
+            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 < 
227)),
+                Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1],
+                T.float32(0),
+                dtype="float32",
+            )
+    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
+        with T.block("conv2d_nhwc"):
+            n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, 
i4, i5, i6])
+            with T.init():
+                Conv2d_nhwc[n, h, w, co] = T.float32(0)
+            Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
+                PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co, 
64) * 3) + rc)]
+                * Weight[rh, rw, rc, co]
+            )
+
+
[email protected]_func
+def conv2d_nhwc_transformed(
+    Input: T.Buffer[(1, 224, 224, 3), "float32"],
+    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
+    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
+) -> None:
+    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+        with T.block("PadInput"):
+            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
+            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
+                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
+                T.float32(0),
+                dtype="float32",
+            )
+    for ax0, ax_1, ax_2 in T.grid(12544, 64, 147):
+        with T.block("conv2d_nhwc"):
+            bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2])
+            T.reads(
+                PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 
21 // 3, bv2 % 3],
+                Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1],
+            )
+            T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1])
+            with T.init():
+                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0)
+            Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = (
+                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]
+                + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 
% 21 // 3, bv2 % 3]
+                * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1]
+            )
+
 # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
 # fmt: on
 
@@ -218,5 +295,41 @@ def test_var_args_sugar():
     tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"])
 
 
+def test_transform_block_layout_basic():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block = sch.get_block("B")
+    sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
+    tvm.ir.assert_structural_equal(elementwise_transformed, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_transform_block_layout_conv2d_nhwc():
+    sch = tir.Schedule(conv2d_nhwc, debug_mask="all")
+    block = sch.get_block("conv2d_nhwc")
+    sch.transform_block_layout(
+        block,
+        lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh * 
7 * 3 + rw * 3 + rc),
+    )
+    tvm.ir.assert_structural_equal(conv2d_nhwc_transformed, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc)
+
+
+def test_transform_block_layout_fail_non_affine():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block = sch.get_block("B")
+    with pytest.raises(tir.ScheduleError):
+        sch.transform_block_layout(block, lambda i, j: (i + j,))
+
+
+def test_transform_block_layout_fail_mixed_iter_type():
+    sch = tir.Schedule(conv2d_nhwc, debug_mask="all")
+    block = sch.get_block("conv2d_nhwc")
+    with pytest.raises(tir.ScheduleError):
+        sch.transform_block_layout(
+            block,
+            lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co * 
7 + rh, rw * 3 + rc),
+        )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to