This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b2a7bb9ee4 [MetaSchedule] Handle output cases for InlineConstantScalars (#14654)
b2a7bb9ee4 is described below
commit b2a7bb9ee476230ca2ef6bea3460ca64a5cad8d6
Author: Yixin Dong <[email protected]>
AuthorDate: Thu Apr 20 02:57:29 2023 +0800
[MetaSchedule] Handle output cases for InlineConstantScalars (#14654)
finished
---
src/meta_schedule/schedule_rule/auto_inline.cc | 5 +++-
...test_meta_schedule_schedule_rule_auto_inline.py | 29 ++++++++++++++++------
2 files changed, 26 insertions(+), 8 deletions(-)
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 22e8396925..d9e033eff8 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -209,7 +209,10 @@ class InlineConstantScalarsNode : public ScheduleRuleNode {
auto block = sch->Get(block_rv);
if (block->reads.size() == 0 && block->writes.size() == 1 &&
block->writes[0]->buffer->shape.size() == 0) {
- sch->ComputeInline(block_rv);
+ auto sref = sch->GetSRef(block_rv);
+ if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true))) {
+ sch->ComputeInline(block_rv);
+ }
}
return {sch};
}
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index 0b2e7fc086..3f43e0133c 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -18,11 +18,13 @@
import pytest
import tvm
-from tvm.tir import Schedule
+import tvm.testing
from tvm import meta_schedule as ms
+from tvm.ir.base import assert_structural_equal
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.script import tir as T
from tvm.target import Target
+from tvm.tir import Schedule
# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
@@ -512,10 +514,23 @@ def test_conv2d_int8_inline_constant_scalars():
sch.reverse_compute_inline(sch.get_block("T_add_1"))
+def test_inline_constant_scalars_skip_output_block():
+ # If the constant scalar block is an output block, it should not be inlined
+
+ @tvm.script.ir_module
+ class Full:
+ @T.prim_func
+ def main(T_full: T.Buffer((), "float32")):
+ with T.block("T_full"):
+ vi = T.axis.spatial(1, 0)
+ T.reads()
+ T.writes(T_full[()])
+ T_full[()] = T.float32(1)
+
+ sch = Schedule(Full)
+ sch = ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("T_full"))[0]
+ assert_structural_equal(sch.mod, Full)
+
+
if __name__ == "__main__":
- test_inline_consumer_chain()
- test_inline_into_cache()
- test_inline_into_multiple_consumers()
- test_inline_pure_spatial()
- test_inline_constant_tensor()
- test_conv2d_int8_inline_constant_scalars()
+ tvm.testing.main()