This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new df429c58d8 [TIR] Allow TransformLayout with non-inversible index map (#14095)
df429c58d8 is described below

commit df429c58d833bbe02ceafb69af1c29c7896218b7
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Mar 3 15:00:15 2023 -0800

    [TIR] Allow TransformLayout with non-inversible index map (#14095)
    
    * [TIR] Allow TransformLayout with non-inversible index map
    
    TransformLayout requires the index map to have an inverse map that can
    be calculated by the analyzer in order to check whether padding is
    added. However, such a check doesn't work in all cases because of
    limitations of the affine analysis, which can only handle a set of
    supported patterns. In some cases, even if the index map doesn't
    introduce padding, the schedule primitive throws
    `TransformationIntroducesPaddingError` because it fails to calculate
    the inverse index map.
    
    It is safe to allow the buffer to be padded without providing pad_value
    because the original loop extent is not changed and the padded region
    is not accessed.
    This PR changes the behavior of `TransformLayout` to allow
    non-invertible index maps.
    
    Previous discussion:
    
    https://discuss.tvm.apache.org/t/conflict-free-shared-memory-permutation-in-tensorir/13959/9
    
    * add assume_injective_transform option
    
    * Apply suggestions from code review
    
    Co-authored-by: Siyuan Feng <[email protected]>
    
    ---------
    
    Co-authored-by: Siyuan Feng <[email protected]>
---
 include/tvm/tir/schedule/schedule.h                |  9 ++++-
 python/tvm/tir/schedule/schedule.py                | 18 ++++++++-
 src/tir/schedule/concrete_schedule.cc              |  5 ++-
 src/tir/schedule/concrete_schedule.h               |  3 +-
 src/tir/schedule/primitive.h                       |  7 +++-
 .../schedule/primitive/layout_transformation.cc    | 44 +++++++++++++---------
 src/tir/schedule/schedule.cc                       |  4 +-
 src/tir/schedule/traced_schedule.cc                |  9 +++--
 src/tir/schedule/traced_schedule.h                 |  3 +-
 .../unittest/test_tir_schedule_transform_layout.py | 34 ++++++++++++++++-
 10 files changed, 104 insertions(+), 32 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 288601d1cc..7f2bdf6b4e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -642,10 +642,17 @@ class ScheduleNode : public runtime::Object {
    *    Algebraic symplifications, branch elimination, and other
    *    optimizations may assume that this precondition is met, and
    *    may result in incorrect results being returned.
+   *
+   * \param assume_injective_transform If set to true, the schedule primitive 
will assume the
+   * index_map is injective and skip checking overlapping of the mapped 
indices. This can be useful
+   * for complicated index_map that the analysis does not cover. It is the 
callers' responsibility
+   * to ensure the index map is injective, otherwise, the correctness of the 
schedule is not
+   * guaranteed.
    */
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
                                BufferIndexType buffer_index_type, const 
IndexMap& index_map,
-                               const Optional<IndexMap>& pad_value = NullOpt) 
= 0;
+                               const Optional<IndexMap>& pad_value = NullOpt,
+                               bool assume_injective_transform = false) = 0;
 
   /*!
    * \brief Apply a transformation represented by IndexMap to block
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 04355eb16e..b63353bcb3 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2575,7 +2575,6 @@ class Schedule(Object):
         buffer: Union[Tuple[str, int], int, str, Buffer],
         required_buffer_type=None,
     ) -> Tuple[str, int, Buffer]:
-
         block_obj: Block = self.get(block)
         block_name = block_obj.name_hint
 
@@ -2645,6 +2644,8 @@ class Schedule(Object):
         buffer: Union[Tuple[str, int], str, Buffer],
         index_map: Union[IndexMap, Callable],
         pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] = 
None,
+        *,
+        assume_injective_transform: bool = False,
     ) -> None:
         """Apply a transformation represented by IndexMap to buffer
 
@@ -2711,6 +2712,13 @@ class Schedule(Object):
             value to be present in the padding in terms of the
             transformed index.
 
+        assume_injective_transform : bool
+
+            If set to true, the schedule  primitive will assume the index_map 
is injective and skip
+            checking overlapping of the mapped indices. This can be useful for 
complicated index_map
+            that the analysis does not cover. It is the callers' 
responsibility to ensure the
+            index map is injective, otherwise, the correctness of the schedule 
is not guaranteed.
+
         Examples
         --------
         Before transform_layout, in TensorIR, the IR is:
@@ -2787,7 +2795,13 @@ class Schedule(Object):
 
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
         _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: 
disable=no-member
-            self, block, buffer_index, buffer_index_type_enum, index_map, 
pad_value
+            self,
+            block,
+            buffer_index,
+            buffer_index_type_enum,
+            index_map,
+            pad_value,
+            assume_injective_transform,
         )
         if axis_separators:
             _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: 
disable=no-member
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index b6af22263e..8af39b24fd 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -800,14 +800,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& 
block_rv, const String& ann
 void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int 
buffer_index,
                                            BufferIndexType buffer_index_type,
                                            const IndexMap& index_map,
-                                           const Optional<IndexMap>& 
pad_value) {
+                                           const Optional<IndexMap>& pad_value,
+                                           bool assume_injective_transform) {
   TVM_TIR_SCHEDULE_BEGIN();
   auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
     return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var));
   };
   auto new_index_map = Substitute(index_map, f_subst);
   tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, 
buffer_index_type,
-                       new_index_map, pad_value);
+                       new_index_map, pad_value, assume_injective_transform);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
 }
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index 44d9e9b69c..41168fb016 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -148,7 +148,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, 
BufferIndexType buffer_index_type,
-                       const IndexMap& index_map, const Optional<IndexMap>& 
pad_value) override;
+                       const IndexMap& index_map, const Optional<IndexMap>& 
pad_value,
+                       bool assume_injective_transform = false) override;
   void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& 
index_map) override;
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index dbc4e23596..0b7a4f6280 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -501,10 +501,15 @@ TVM_DLL void Unannotate(ScheduleState self, const 
StmtSRef& sref, const String&
  * \param buffer_index_type The type of the buffer index, kRead or kWrite.
  * \param index_map The transformation to apply.
  * \param pad_value The value to write into padding introduced by the 
transformation.
+ * \param assume_injective_transform If set to true, the schedule primitive 
will assume the
+ * index_map is injective and skip checking overlapping of the mapped indices. 
This can be useful
+ * for complicated index_map that the analysis does not cover. It is the 
callers' responsibility
+ * to ensure the index map is injective, otherwise, the correctness of the 
schedule is not
+ * guaranteed.
  */
 TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, 
int buffer_index,
                              BufferIndexType buffer_index_type, const 
IndexMap& index_map,
-                             const Optional<IndexMap>& pad_value);
+                             const Optional<IndexMap>& pad_value, bool 
assume_injective_transform);
 
 /*!
  * \brief Apply a transformation represented by IndexMap to block
diff --git a/src/tir/schedule/primitive/layout_transformation.cc 
b/src/tir/schedule/primitive/layout_transformation.cc
index 0e993d06dc..7eaca74100 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -753,10 +753,12 @@ class TransformLayoutRewriter : private 
arith::IRMutatorWithAnalyzer {
    */
   static std::pair<Stmt, Map<Block, Block>> Rewrite(
       const Block& scope_stmt, const Buffer& old_buffer, const Buffer& 
new_buffer,
-      const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& 
padding_predicate,
-      const Optional<IndexMap>& pad_value) {
-    auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, 
new_buffer, index_map, inverse,
-                                             padding_predicate, pad_value);
+      const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
+      const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
+    auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
+                                          scope_stmt, old_buffer, new_buffer, 
index_map,
+                                          opt_inverse.value(), 
padding_predicate, pad_value)
+                                    : 
TransformLayoutPlanner::NoPaddingRequired();
 
     arith::Analyzer analyzer;
     TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, 
&analyzer);
@@ -1119,7 +1121,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, 
const Array<PrimExpr>&
 
 void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
                      BufferIndexType buffer_index_type, const IndexMap& 
index_map_orig,
-                     const Optional<IndexMap>& pad_value) {
+                     const Optional<IndexMap>& pad_value, bool 
assume_injective_transform) {
   // Step 1: Input handling and error checking
   const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Buffer old_buffer =
@@ -1147,13 +1149,17 @@ void TransformLayout(ScheduleState self, const 
StmtSRef& block_sref, int buffer_
                             : GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
   const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
 
-  auto [inverse, padding_predicate] = [&]() {
-    Array<Range> region;
-    for (const auto& dim : old_buffer->shape) {
-      region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
-    }
-    return index_map.NonSurjectiveInverse(region);
-  }();
+  Optional<IndexMap> opt_inverse = NullOpt;
+  PrimExpr padding_predicate = Bool(false);
+  if (!assume_injective_transform) {
+    std::tie(opt_inverse, padding_predicate) = [&]() {
+      Array<Range> region;
+      for (const auto& dim : old_buffer->shape) {
+        region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
+      }
+      return index_map.NonSurjectiveInverse(region);
+    }();
+  }
 
   bool has_padding = !is_zero(padding_predicate);
   if (has_padding && !pad_value.defined()) {
@@ -1168,7 +1174,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& 
block_sref, int buffer_
   // alloc_buffers.
   auto [new_stmt, block_sref_reuse] =
       TransformLayoutRewriter::Rewrite(GetRef<Block>(scope_block), old_buffer, 
new_buffer,
-                                       index_map, inverse, padding_predicate, 
pad_value);
+                                       index_map, opt_inverse, 
padding_predicate, pad_value);
   Block new_scope_block = Downcast<Block>(new_stmt);
 
   // Step 4: Rewrite buffer_map of the PrimFunc if necessary.
@@ -1511,20 +1517,21 @@ struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits>
 
  private:
   static constexpr size_t kNumInputs = 2;
-  static constexpr size_t kNumAttrs = 3;
+  static constexpr size_t kNumAttrs = 4;
   static constexpr size_t kNumDecisions = 0;
 
   static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap 
index_map,
                                       Integer buffer_index, Integer 
buffer_index_type,
-                                      Optional<IndexMap> pad_value) {
+                                      Optional<IndexMap> pad_value,
+                                      Bool assume_injective_transform) {
     return sch->TransformLayout(block_rv, buffer_index.IntValue(),
                                 
static_cast<BufferIndexType>(buffer_index_type->value), index_map,
-                                pad_value);
+                                pad_value, assume_injective_transform.operator 
bool());
   }
 
   static String UnpackedAsPython(Array<String> outputs, String block_rv, 
IndexMap index_map,
                                  Integer buffer_index, Integer 
buffer_index_type,
-                                 Optional<IndexMap> pad_value) {
+                                 Optional<IndexMap> pad_value, Bool 
assume_injective_transform) {
     PythonAPICall py("transform_layout");
     py.Input("block", block_rv);
 
@@ -1534,6 +1541,7 @@ struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits>
     py.Input("buffer", os.str());
     py.Input("index_map", index_map->ToPythonString());
     py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : 
"None");
+    py.Input("assume_injective_transform", assume_injective_transform.operator 
bool());
 
     return py.Str();
   }
@@ -1549,6 +1557,7 @@ struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits>
     } else {
       attrs_record.push_back(attrs[2]);
     }
+    attrs_record.push_back(attrs[3]);
     return std::move(attrs_record);
   }
 
@@ -1562,6 +1571,7 @@ struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits>
     } else {
       attrs.push_back(attrs_record[2]);
     }
+    attrs.push_back(attrs_record[3]);
     return attrs;
   }
 
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index d008f3639c..4177d91648 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -253,10 +253,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
     .set_body_typed([](Schedule self, const BlockRV& block_rv, int 
buffer_index,
                        int buffer_index_type, const IndexMap& index_map,
-                       const Optional<IndexMap>& pad_value) {
+                       const Optional<IndexMap>& pad_value, bool 
assume_injective_transform) {
       return self->TransformLayout(block_rv, buffer_index,
                                    
static_cast<BufferIndexType>(buffer_index_type), index_map,
-                                   pad_value);
+                                   pad_value, assume_injective_transform);
     });
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
     .set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index 8852fccf88..dba34c2ca3 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -523,15 +523,18 @@ void TracedScheduleNode::Unannotate(const BlockRV& 
block_rv, const String& ann_k
 void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int 
buffer_index,
                                          BufferIndexType buffer_index_type,
                                          const IndexMap& index_map,
-                                         const Optional<IndexMap>& pad_value) {
+                                         const Optional<IndexMap>& pad_value,
+                                         bool assume_injective_transform) {
   ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, 
buffer_index_type, index_map,
-                                        pad_value);
+                                        pad_value, assume_injective_transform);
   static const InstructionKind& kind = InstructionKind::Get("TransformLayout");
   trace_->Append(
       /*inst=*/Instruction(
           /*kind=*/kind,
           /*inputs=*/{block_rv, index_map},
-          /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), 
pad_value},
+          /*attrs=*/
+          {Integer(buffer_index), Integer(buffer_index_type), pad_value,
+           Bool(assume_injective_transform)},
           /*outputs=*/{}));
 }
 
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index ee65c721ad..7bd8385555 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -107,7 +107,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
   /******** Schedule: Layout transformation ********/
   void TransformLayout(const BlockRV& block_rv, int buffer_index, 
BufferIndexType buffer_index_type,
-                       const IndexMap& index_map, const Optional<IndexMap>& 
pad_value) override;
+                       const IndexMap& index_map, const Optional<IndexMap>& 
pad_value,
+                       bool assume_injective_transform) override;
   void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& 
index_map) override;
   void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
                         BufferIndexType buffer_index_type,
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py 
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index d866de33f1..c9a8f70ef7 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -477,11 +477,19 @@ class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
 
     index_map = tvm.testing.parameter(lambda i: [i // 4, i % 4])
 
+    assume_injective_transform = tvm.testing.parameter(False)
+
     @pytest.fixture
-    def transform(self, pad_value, transformed_buffer, index_map):
+    def transform(self, pad_value, transformed_buffer, index_map, 
assume_injective_transform):
         def transform(mod):
             sch = tir.Schedule(mod)
-            sch.transform_layout("block", transformed_buffer, index_map, 
pad_value=pad_value)
+            sch.transform_layout(
+                "block",
+                transformed_buffer,
+                index_map,
+                pad_value=pad_value,
+                assume_injective_transform=assume_injective_transform,
+            )
             return sch.mod
 
         return transform
@@ -578,6 +586,28 @@ class TestErrorIfPaddingForbidden(BasePaddingCompare):
     expected = tvm.tir.schedule.schedule.ScheduleError
 
 
+class TestImplicitPaddingAssumeInjective(BasePaddingCompare):
+    """When pad_value is None and assume_injective_transform is set, the 
buffer can be implicitly
+    padded. The padded region is not accessed because the original loop extent 
is not changed.
+    """
+
+    assume_injective_transform = tvm.testing.parameter(True)
+
+    def before():
+        A = T.alloc_buffer(14, "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                A[vi] = 0
+
+    def expected():
+        A = T.alloc_buffer([4, 4], "int32")
+        for i in T.serial(14):
+            with T.block("block"):
+                vi = T.axis.remap("S", [i])
+                A[vi // 4, vi % 4] = 0
+
+
 class TestErrorOnWrongPaddingType(BasePaddingCompare):
     """The padding must have the same dtype as the buffer"""
 

Reply via email to