This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e97186957 [MetaSchedule] Enhance AutoInline for Spatial Task (#11996)
0e97186957 is described below

commit 0e971869575df7e5b12381e4566a1a8fd98a4a77
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jul 2 04:16:25 2022 -0700

    [MetaSchedule] Enhance AutoInline for Spatial Task (#11996)
    
    Previously, Auto-Inline on CPU would only inline according to strict
    conditions, for example, ordered index mapping. This is generally good
    practice, but on the other hand, there is not much benefit in stopping
    inlining only due to some restrictive conditions for pure spatial
    subgraphs. By doing so, we also save some search trials on pure spatial
    subgraphs so that more can be allocated to more important ones.
---
 src/meta_schedule/schedule_rule/auto_inline.cc     | 16 +++-
 ...test_meta_schedule_schedule_rule_auto_inline.py | 93 ++++++++++++++++++++++
 2 files changed, 106 insertions(+), 3 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc 
b/src/meta_schedule/schedule_rule/auto_inline.cc
index 0cfe35298d..309f0a60ac 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -31,6 +31,15 @@ enum class InlineType : int32_t {
   kInlineIntoProducer = 2,
 };
 
+bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& 
block_sref) {
+  using namespace tvm::tir;
+  const StmtSRefNode* sref = block_sref.get();
+  for (; sref->parent != nullptr; sref = sref->parent) {
+  }
+  ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance<BlockNode>());
+  return IsSpatialPrimFunc(GetRef<PrimFunc>(GetRootPrimFunc(sch->mod(), 
sref->stmt, nullptr)));
+}
+
 /*! \brief The rule that inlines spatial blocks if it satisfies some 
conditions. */
 class AutoInlineNode : public ScheduleRuleNode {
  public:
@@ -85,6 +94,7 @@ inline InlineType AutoInlineNode::CheckInline(const 
tir::Schedule& sch,
                                               const tir::BlockRV& block_rv) {
   using namespace tvm::tir;
   StmtSRef block_sref = sch->GetSRef(block_rv);
+  bool is_pure_spatial = IsInSpatialPrimFunc(sch, block_sref);
   ScheduleState state = sch->state();
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   BlockRealize realize = GetBlockRealize(state, block_sref);
@@ -97,15 +107,15 @@ inline InlineType AutoInlineNode::CheckInline(const 
tir::Schedule& sch,
     return InlineType::kInlineIntoConsumer;
   }
   // Cond 3. The block doesn't contain any disallowed operators
-  if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
+  if (!is_pure_spatial && !disallow_op.empty() && HasOp(realize, disallow_op)) {
     return InlineType::kNoInline;
   }
   // Cond 4. The block doesn't have any if-then-else-like constructs
-  if (disallow_if_then_else && HasIfThenElse(realize)) {
+  if (!is_pure_spatial && disallow_if_then_else && HasIfThenElse(realize)) {
     return InlineType::kNoInline;
   }
   // Cond 5. The mapping from read indices to write indices are injective and 
ordered
-  if (require_injective || require_ordered) {
+  if (!is_pure_spatial && (require_injective || require_ordered)) {
     const BufferRegion& write_region = block->writes[0];
     const BufferRegion& write_region = block->writes[0];
     for (const BufferRegion& read_region : block->reads) {
       bool injective, ordered;
diff --git 
a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index 2a8a1e5fe1..a8ffa6ff9d 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -240,6 +240,86 @@ class SoftmaxAfterInline:
                 T_softmax_norm[i0_4, i1_1] = T.exp(A[i0_4, i1_1] - 
T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4]
 
 
[email protected]_module
+class BeforePureSpatial:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 384), "int64"],
+        placeholder_1: T.Buffer[(30522, 768), "float32"],
+        placeholder_2: T.Buffer[(1, 384, 768), "float32"],
+        T_add: T.Buffer[(1, 384, 768), "float32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        compile_engine_const = T.alloc_buffer([], dtype="int64")
+        T_less = T.alloc_buffer([1, 384], dtype="bool")
+        compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
+        T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
+        T_where = T.alloc_buffer([1, 384], dtype="int64")
+        T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
+        with T.block("compile_engine_const"):
+            vi = T.axis.spatial(1, 0)
+            T.reads()
+            T.writes(compile_engine_const[()])
+            compile_engine_const[()] = T.int64(0)
+        for i0, i1 in T.grid(1, 384):
+            with T.block("T_less"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(placeholder[ax0, ax1], compile_engine_const[()])
+                T.writes(T_less[ax0, ax1])
+                T_less[ax0, ax1] = placeholder[ax0, ax1] < 
compile_engine_const[()]
+        with T.block("compile_engine_const_1"):
+            vi = T.axis.spatial(1, 0)
+            T.reads()
+            T.writes(compile_engine_const_1[()])
+            compile_engine_const_1[()] = T.int64(30522)
+        for i0, i1 in T.grid(1, 384):
+            with T.block("T_add"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
+                T.writes(T_add_1[ax0, ax1])
+                T_add_1[ax0, ax1] = placeholder[ax0, ax1] + 
compile_engine_const_1[()]
+        for i0, i1 in T.grid(1, 384):
+            with T.block("T_where"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0, 
ax1])
+                T.writes(T_where[ax0, ax1])
+                T_where[ax0, ax1] = T.Select(
+                    T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1], 
placeholder[ax0, ax1]
+                )
+        for i0, i1, i2 in T.grid(1, 384, 768):
+            with T.block("T_take"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(
+                    placeholder_1[T.min(T.max(T.int64(0), T_where[ax0, ax1]), 
T.int64(30521)), ax2],
+                    T_where[ax0, ax1],
+                )
+                T.writes(T_take[ax0, ax1, ax2])
+                T_take[ax0, ax1, ax2] = placeholder_1[
+                    T.min(T.max(T.int64(0), T_where[ax0, ax1]), 
T.int64(30521)), ax2
+                ]
+        for i0, i1, i2 in T.grid(1, 384, 768):
+            with T.block("T_add_1"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
+                T.writes(T_add[ax0, ax1, ax2])
+                T_add[ax0, ax1, ax2] = T_take[ax0, ax1, ax2] + 
placeholder_2[ax0, ax1, ax2]
+
+
[email protected]_module
+class AfterPureSpatial:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(1, 384), "int64"], placeholder_1: 
T.Buffer[(30522, 768), "float32"], placeholder_2: T.Buffer[(1, 384, 768), 
"float32"], T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2 in T.grid(1, 384, 768):
+            with T.block("T_add_1"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(placeholder[ax0, ax1], 
placeholder_1[T.min(T.max(T.int64(0), placeholder[ax0, ax1]), T.int64(30521)) : 
T.min(T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)), 
T.int64(30521)) + T.int64(1), ax2], placeholder_2[ax0, ax1, ax2])
+                T.writes(T_add[ax0, ax1, ax2])
+                T_add[ax0, ax1, ax2] = placeholder_1[T.min(T.max(T.int64(0), 
T.Select(T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0, 
placeholder[ax0, ax1] + T.int64(30522), placeholder[ax0, ax1])), 
T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
+
 # pylint: 
enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -291,7 +371,20 @@ def test_inline_into_multiple_consumers():
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline)
 
 
+def test_inline_pure_spatial():
+    mod = BeforePureSpatial
+    target = Target("llvm")
+    ctx = _create_context(
+        mod=mod,
+        target=target,
+        rule=auto_inline(target=target),
+    )
+    (space,) = ctx.space_generator.generate_design_space(mod=mod)
+    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial)
+
+
 if __name__ == "__main__":
     test_inline_consumer_chain()
     test_inline_into_cache()
     test_inline_into_multiple_consumers()
+    test_inline_pure_spatial()

Reply via email to