This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b2a7bb9ee4 [MetaSchedule] Handle output cases for InlineConstantScalars (#14654)
b2a7bb9ee4 is described below

commit b2a7bb9ee476230ca2ef6bea3460ca64a5cad8d6
Author: Yixin Dong <[email protected]>
AuthorDate: Thu Apr 20 02:57:29 2023 +0800

    [MetaSchedule] Handle output cases for InlineConstantScalars (#14654)
    
    finished
---
 src/meta_schedule/schedule_rule/auto_inline.cc     |  5 +++-
 ...test_meta_schedule_schedule_rule_auto_inline.py | 29 ++++++++++++++++------
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 22e8396925..d9e033eff8 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -209,7 +209,10 @@ class InlineConstantScalarsNode : public ScheduleRuleNode {
     auto block = sch->Get(block_rv);
     if (block->reads.size() == 0 && block->writes.size() == 1 &&
         block->writes[0]->buffer->shape.size() == 0) {
-      sch->ComputeInline(block_rv);
+      auto sref = sch->GetSRef(block_rv);
+      if (!tir::IsOutputBlock(sch->state(), sref, 
tir::GetScopeRoot(sch->state(), sref, true))) {
+        sch->ComputeInline(block_rv);
+      }
     }
     return {sch};
   }
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index 0b2e7fc086..3f43e0133c 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -18,11 +18,13 @@
 import pytest
 
 import tvm
-from tvm.tir import Schedule
+import tvm.testing
 from tvm import meta_schedule as ms
+from tvm.ir.base import assert_structural_equal
 from tvm.meta_schedule.testing.space_generation import generate_design_space
 from tvm.script import tir as T
 from tvm.target import Target
+from tvm.tir import Schedule
 
 # fmt: off
 # pylint: 
disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
@@ -512,10 +514,23 @@ def test_conv2d_int8_inline_constant_scalars():
     sch.reverse_compute_inline(sch.get_block("T_add_1"))
 
 
+def test_inline_constant_scalars_skip_output_block():
+    # If the constant scalar block is an output block, it should not be inlined
+
+    @tvm.script.ir_module
+    class Full:
+        @T.prim_func
+        def main(T_full: T.Buffer((), "float32")):
+            with T.block("T_full"):
+                vi = T.axis.spatial(1, 0)
+                T.reads()
+                T.writes(T_full[()])
+                T_full[()] = T.float32(1)
+
+    sch = Schedule(Full)
+    sch = ms.schedule_rule.InlineConstantScalars().apply(sch, 
sch.get_block("T_full"))[0]
+    assert_structural_equal(sch.mod, Full)
+
+
 if __name__ == "__main__":
-    test_inline_consumer_chain()
-    test_inline_into_cache()
-    test_inline_into_multiple_consumers()
-    test_inline_pure_spatial()
-    test_inline_constant_tensor()
-    test_conv2d_int8_inline_constant_scalars()
+    tvm.testing.main()

Reply via email to