This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a23b71ce1e [MetaSchedule][Test] Migrate AddRFactor to SEqual (#12758)
a23b71ce1e is described below

commit a23b71ce1e3011be6b8e6ca5162b023956358911
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 12 15:42:40 2022 -0800

    [MetaSchedule][Test] Migrate AddRFactor to SEqual (#12758)
    
    This PR migrates the usage of `check_trace` to `check_sketch`,
    which prefers structural equality of TIRs instead of string equality
    of traces.
---
 python/tvm/meta_schedule/testing/schedule_rule.py  |  16 +--
 python/tvm/tir/schedule/testing.py                 |   8 +-
 src/meta_schedule/schedule_rule/add_rfactor.cc     |   5 +-
 src/tir/schedule/primitive/sampling.cc             |   4 +-
 ...test_meta_schedule_schedule_rule_add_rfactor.py | 142 ++++++++++++++-------
 5 files changed, 109 insertions(+), 66 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py 
b/python/tvm/meta_schedule/testing/schedule_rule.py
index 46df4b95ce..b08db0811d 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -18,7 +18,6 @@
 from typing import List, Union
 
 from tvm.meta_schedule.schedule_rule import (
-    AddRFactor,
     AutoBind,
     AutoInline,
     CrossThreadReduction,
@@ -28,7 +27,9 @@ from tvm.meta_schedule.schedule_rule import (
     ReuseType,
     ScheduleRule,
 )
-from tvm.meta_schedule.schedule_rule.multi_level_tiling import 
MultiLevelTilingTensorCore
+from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
+    MultiLevelTilingTensorCore,
+)
 from tvm.target import Target
 
 
@@ -64,13 +65,6 @@ def auto_inline(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
-def add_rfactor(target: Target) -> ScheduleRule:
-    """Default schedule rules for with add_rfactor"""
-    if target.kind.name == "llvm":
-        return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
 def cross_thread_reduction(target: Target) -> ScheduleRule:
     """Default schedule rules for with cross-thread reduction"""
     if target.kind.name == "cuda":
@@ -131,7 +125,9 @@ def multi_level_tiling_tensor_core(
         trans_b = [trans_b]
 
     if target.kind.name == "cuda":
-        from tvm.tir.tensor_intrin import cuda  # pylint: 
disable=import-outside-toplevel
+        from tvm.tir.tensor_intrin import (  # pylint: 
disable=import-outside-toplevel
+            cuda,
+        )
 
         intrin_groups = [
             cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, 
_out_dtype, _trans_b)
diff --git a/python/tvm/tir/schedule/testing.py 
b/python/tvm/tir/schedule/testing.py
index 3689f756e8..538cc6e143 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -15,12 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilities for the TensorIR schedule API"""
-from typing import Union, Sequence
+from typing import Sequence, Union
 
 import tvm
-from tvm.ir import IRModule, structural_equal
+from tvm.ir import IRModule, assert_structural_equal
 from tvm.tir import PrimFunc
-from tvm.tir.schedule import Trace, Schedule
+from tvm.tir.schedule import Schedule, Trace
 
 
 def verify_trace_roundtrip(
@@ -70,7 +70,7 @@ def verify_trace_roundtrip(
         assert text_format in ("json", "python"), f"Unknown text format: 
{text_format}"
 
     # Step 2. Verify that the round-trip produced the same scheduling
-    assert structural_equal(new_sch.mod, sch.mod)
+    assert_structural_equal(new_sch.mod, sch.mod)
 
     # Step 3. Check the consistency of the text format between the old and new 
traces
     py_repr = "\n".join(trace.as_python())
diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc 
b/src/meta_schedule/schedule_rule/add_rfactor.cc
index 5ef2ac3aad..cf87f24ac2 100644
--- a/src/meta_schedule/schedule_rule/add_rfactor.cc
+++ b/src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -90,8 +90,7 @@ Array<tir::Schedule> AddRFactorNode::Apply(const 
tir::Schedule& sch, const tir::
 
   // Split the fused reduction loop.
   Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, 
max_innermost_factor);
-  const Array<tir::LoopRV>& split_loops =
-      sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
+  Array<tir::LoopRV> split_loops = sch->Split(fused_reduce_loop, 
{factors.begin(), factors.end()});
 
   Array<tir::Schedule> res;
   for (const tir::LoopRV& split_loop : split_loops) {
@@ -104,7 +103,7 @@ Array<tir::Schedule> AddRFactorNode::Apply(const 
tir::Schedule& sch, const tir::
 
       // Annotate that the rfactor block, which is now the producer of the 
original block, needs to
       // be considered by the rule Random-Compute-Location.
-      sch_tmp->Annotate(block_rv, 
tir::attr::meta_schedule_random_compute_producer, Bool(true));
+      sch_tmp->Annotate(block_rv, 
tir::attr::meta_schedule_random_compute_producer, Integer(1));
       res.push_back(sch_tmp);
     } catch (const tvm::runtime::Error& e) {
     }
diff --git a/src/tir/schedule/primitive/sampling.cc 
b/src/tir/schedule/primitive/sampling.cc
index b1001a7f94..ec12b045d3 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -338,7 +338,9 @@ std::vector<int64_t> SamplePerfectTile(
   } else {
     // Case 3. Use fresh new sampling result
     result = SamplePerfectTile(rand_state, *extent, n_splits, 
max_innermost_factor);
-    ICHECK_LE(result.back(), max_innermost_factor);
+    if (max_innermost_factor != -1) {
+      ICHECK_LE(result.back(), max_innermost_factor);
+    }
   }
   *decision = support::AsArray<int64_t, Integer>(result);
   return result;
diff --git 
a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
index a39c8aea5f..17f42654fc 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -15,62 +15,108 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: 
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing import te_workload
-from tvm.meta_schedule.testing.schedule_rule import add_rfactor
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule.testing.space_generation import check_sketches
+from tvm.script import tir as T
 from tvm.target import Target
-from tvm.te.operation import create_prim_func
+from tvm.te import create_prim_func
 
 
-def _create_context(mod, target, rule) -> TuneContext:
-    ctx = TuneContext(
-        mod=mod,
-        target=target,
-        space_generator=PostOrderApply(),
-        sch_rules=[rule],
-        task_name="test",
-    )
-    return ctx
+def test_cpu_matmul():
+    @T.prim_func
+    def cpu_matmul_0(
+        A: T.Buffer[(4, 512), "float32"],
+        B: T.Buffer[(512, 4), "float32"],
+        C: T.Buffer[(4, 4), "float32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        for i0, i1, i2 in T.grid(4, 4, 512):
+            with T.block("C"):
+                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+                T.reads(A[i, k], B[k, j])
+                T.writes(C[i, j])
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + A[i, k] * B[k, j]
 
+    @T.prim_func
+    def cpu_matmul_1(
+        A: T.Buffer[(4, 512), "float32"],
+        B: T.Buffer[(512, 4), "float32"],
+        C: T.Buffer[(4, 4), "float32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        C_rf = T.alloc_buffer([4, 4, 128], dtype="float32")
+        for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128):
+            with T.block("C_rf"):
+                vi2_1, i, j, vi2_0 = T.axis.remap("SSSR", [i2_1, i0, i1, i2_0])
+                T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j])
+                T.writes(C_rf[i, j, vi2_1])
+                with T.init():
+                    C_rf[i, j, vi2_1] = T.float32(0)
+                C_rf[i, j, vi2_1] = (
+                    C_rf[i, j, vi2_1] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 
128 + vi2_1, j]
+                )
+        for i0, i1, i2_1 in T.grid(4, 4, 128):
+            with T.block("C"):
+                vi2_1, i, j = T.axis.remap("RSS", [i2_1, i0, i1])
+                T.reads(C_rf[i, j, vi2_1])
+                T.writes(C[i, j])
+                T.block_attr({"meta_schedule.random_compute_producer": 1})
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + C_rf[i, j, vi2_1]
 
-def test_cpu_matmul():
-    expected = [
-        [],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, 
max_innermost_factor=64)",
-            "l6, l7 = sch.split(loop=l3, factors=[v4, v5], 
preserve_unit_iters=True)",
-            "b8 = sch.rfactor(loop=l7, factor_axis=2)",
-            'sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.random_compute_producer", ann_val=1)',
-        ],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            "l1, l2, l3 = sch.get_loops(block=b0)",
-            "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, 
max_innermost_factor=64)",
-            "l6, l7 = sch.split(loop=l3, factors=[v4, v5], 
preserve_unit_iters=True)",
-            "b8 = sch.rfactor(loop=l6, factor_axis=2)",
-            'sch.annotate(block_or_loop=b0, 
ann_key="meta_schedule.random_compute_producer", ann_val=1)',
-        ],
+    @T.prim_func
+    def cpu_matmul_2(
+        A: T.Buffer[(4, 512), "float32"],
+        B: T.Buffer[(512, 4), "float32"],
+        C: T.Buffer[(4, 4), "float32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        C_rf = T.alloc_buffer([4, 4, 4], dtype="float32")
+        for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128):
+            with T.block("C_rf"):
+                vi2_0, i, j, vi2_1 = T.axis.remap("SSSR", [i2_0, i0, i1, i2_1])
+                T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j])
+                T.writes(C_rf[i, j, vi2_0])
+                with T.init():
+                    C_rf[i, j, vi2_0] = T.float32(0)
+                C_rf[i, j, vi2_0] = (
+                    C_rf[i, j, vi2_0] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 
128 + vi2_1, j]
+                )
+        for i0, i1, i2_0 in T.grid(4, 4, 4):
+            with T.block("C"):
+                vi2_0, i, j = T.axis.remap("RSS", [i2_0, i0, i1])
+                T.reads(C_rf[i, j, vi2_0])
+                T.writes(C[i, j])
+                T.block_attr({"meta_schedule.random_compute_producer": 1})
+                with T.init():
+                    C[i, j] = T.float32(0)
+                C[i, j] = C[i, j] + C_rf[i, j, vi2_0]
+
+    decision_0 = []  # type: ignore
+    decision_1 = [
+        ("SamplePerfectTile", [4, 128]),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [4, 128]),
     ]
-    target = Target("llvm --num-cores=32")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul(
-                n=4,
-                m=4,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=add_rfactor(target=target),
+    mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512))
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target("llvm --num-cores=32"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[ms.schedule_rule.AddRFactor()],
+        task_name="test",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cpu_matmul_0, cpu_matmul_1, cpu_matmul_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
     )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 3
-    check_trace(spaces, expected)
 
 
 if __name__ == "__main__":

Reply via email to