This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a23b71ce1e [MetaSchedule][Test] Migrate AddRFactor to SEqual (#12758)
a23b71ce1e is described below
commit a23b71ce1e3011be6b8e6ca5162b023956358911
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 12 15:42:40 2022 -0800
[MetaSchedule][Test] Migrate AddRFactor to SEqual (#12758)
This PR migrates the usage of `check_trace` to `check_sketch`,
which prefers structural equality of TIRs instead of string equality
of traces.
---
python/tvm/meta_schedule/testing/schedule_rule.py | 16 +--
python/tvm/tir/schedule/testing.py | 8 +-
src/meta_schedule/schedule_rule/add_rfactor.cc | 5 +-
src/tir/schedule/primitive/sampling.cc | 4 +-
...test_meta_schedule_schedule_rule_add_rfactor.py | 142 ++++++++++++++-------
5 files changed, 109 insertions(+), 66 deletions(-)
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py
b/python/tvm/meta_schedule/testing/schedule_rule.py
index 46df4b95ce..b08db0811d 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -18,7 +18,6 @@
from typing import List, Union
from tvm.meta_schedule.schedule_rule import (
- AddRFactor,
AutoBind,
AutoInline,
CrossThreadReduction,
@@ -28,7 +27,9 @@ from tvm.meta_schedule.schedule_rule import (
ReuseType,
ScheduleRule,
)
-from tvm.meta_schedule.schedule_rule.multi_level_tiling import
MultiLevelTilingTensorCore
+from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
+ MultiLevelTilingTensorCore,
+)
from tvm.target import Target
@@ -64,13 +65,6 @@ def auto_inline(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")
-def add_rfactor(target: Target) -> ScheduleRule:
- """Default schedule rules for with add_rfactor"""
- if target.kind.name == "llvm":
- return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
- raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
def cross_thread_reduction(target: Target) -> ScheduleRule:
"""Default schedule rules for with cross-thread reduction"""
if target.kind.name == "cuda":
@@ -131,7 +125,9 @@ def multi_level_tiling_tensor_core(
trans_b = [trans_b]
if target.kind.name == "cuda":
- from tvm.tir.tensor_intrin import cuda # pylint:
disable=import-outside-toplevel
+ from tvm.tir.tensor_intrin import ( # pylint:
disable=import-outside-toplevel
+ cuda,
+ )
intrin_groups = [
cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype,
_out_dtype, _trans_b)
diff --git a/python/tvm/tir/schedule/testing.py
b/python/tvm/tir/schedule/testing.py
index 3689f756e8..538cc6e143 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities for the TensorIR schedule API"""
-from typing import Union, Sequence
+from typing import Sequence, Union
import tvm
-from tvm.ir import IRModule, structural_equal
+from tvm.ir import IRModule, assert_structural_equal
from tvm.tir import PrimFunc
-from tvm.tir.schedule import Trace, Schedule
+from tvm.tir.schedule import Schedule, Trace
def verify_trace_roundtrip(
@@ -70,7 +70,7 @@ def verify_trace_roundtrip(
assert text_format in ("json", "python"), f"Unknown text format:
{text_format}"
# Step 2. Verify that the round-trip produced the same scheduling
- assert structural_equal(new_sch.mod, sch.mod)
+ assert_structural_equal(new_sch.mod, sch.mod)
# Step 3. Check the consistency of the text format between the old and new
traces
py_repr = "\n".join(trace.as_python())
diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc
b/src/meta_schedule/schedule_rule/add_rfactor.cc
index 5ef2ac3aad..cf87f24ac2 100644
--- a/src/meta_schedule/schedule_rule/add_rfactor.cc
+++ b/src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -90,8 +90,7 @@ Array<tir::Schedule> AddRFactorNode::Apply(const
tir::Schedule& sch, const tir::
// Split the fused reduction loop.
Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2,
max_innermost_factor);
- const Array<tir::LoopRV>& split_loops =
- sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
+ Array<tir::LoopRV> split_loops = sch->Split(fused_reduce_loop,
{factors.begin(), factors.end()});
Array<tir::Schedule> res;
for (const tir::LoopRV& split_loop : split_loops) {
@@ -104,7 +103,7 @@ Array<tir::Schedule> AddRFactorNode::Apply(const
tir::Schedule& sch, const tir::
// Annotate that the rfactor block, which is now the producer of the
original block, needs to
// be considered by the rule Random-Compute-Location.
- sch_tmp->Annotate(block_rv,
tir::attr::meta_schedule_random_compute_producer, Bool(true));
+ sch_tmp->Annotate(block_rv,
tir::attr::meta_schedule_random_compute_producer, Integer(1));
res.push_back(sch_tmp);
} catch (const tvm::runtime::Error& e) {
}
diff --git a/src/tir/schedule/primitive/sampling.cc
b/src/tir/schedule/primitive/sampling.cc
index b1001a7f94..ec12b045d3 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -338,7 +338,9 @@ std::vector<int64_t> SamplePerfectTile(
} else {
// Case 3. Use fresh new sampling result
result = SamplePerfectTile(rand_state, *extent, n_splits,
max_innermost_factor);
- ICHECK_LE(result.back(), max_innermost_factor);
+ if (max_innermost_factor != -1) {
+ ICHECK_LE(result.back(), max_innermost_factor);
+ }
}
*decision = support::AsArray<int64_t, Integer>(result);
return result;
diff --git
a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
index a39c8aea5f..17f42654fc 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -15,62 +15,108 @@
# specific language governing permissions and limitations
# under the License.
# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
-
-from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm import meta_schedule as ms
from tvm.meta_schedule.testing import te_workload
-from tvm.meta_schedule.testing.schedule_rule import add_rfactor
-from tvm.meta_schedule.testing.space_generation import check_trace
-from tvm.meta_schedule.tune_context import TuneContext
+from tvm.meta_schedule.testing.space_generation import check_sketches
+from tvm.script import tir as T
from tvm.target import Target
-from tvm.te.operation import create_prim_func
+from tvm.te import create_prim_func
-def _create_context(mod, target, rule) -> TuneContext:
- ctx = TuneContext(
- mod=mod,
- target=target,
- space_generator=PostOrderApply(),
- sch_rules=[rule],
- task_name="test",
- )
- return ctx
+def test_cpu_matmul():
+ @T.prim_func
+ def cpu_matmul_0(
+ A: T.Buffer[(4, 512), "float32"],
+ B: T.Buffer[(512, 4), "float32"],
+ C: T.Buffer[(4, 4), "float32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ for i0, i1, i2 in T.grid(4, 4, 512):
+ with T.block("C"):
+ i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+ T.reads(A[i, k], B[k, j])
+ T.writes(C[i, j])
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+ @T.prim_func
+ def cpu_matmul_1(
+ A: T.Buffer[(4, 512), "float32"],
+ B: T.Buffer[(512, 4), "float32"],
+ C: T.Buffer[(4, 4), "float32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ C_rf = T.alloc_buffer([4, 4, 128], dtype="float32")
+ for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128):
+ with T.block("C_rf"):
+ vi2_1, i, j, vi2_0 = T.axis.remap("SSSR", [i2_1, i0, i1, i2_0])
+ T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j])
+ T.writes(C_rf[i, j, vi2_1])
+ with T.init():
+ C_rf[i, j, vi2_1] = T.float32(0)
+ C_rf[i, j, vi2_1] = (
+ C_rf[i, j, vi2_1] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 *
128 + vi2_1, j]
+ )
+ for i0, i1, i2_1 in T.grid(4, 4, 128):
+ with T.block("C"):
+ vi2_1, i, j = T.axis.remap("RSS", [i2_1, i0, i1])
+ T.reads(C_rf[i, j, vi2_1])
+ T.writes(C[i, j])
+ T.block_attr({"meta_schedule.random_compute_producer": 1})
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + C_rf[i, j, vi2_1]
-def test_cpu_matmul():
- expected = [
- [],
- [
- 'b0 = sch.get_block(name="C", func_name="main")',
- "l1, l2, l3 = sch.get_loops(block=b0)",
- "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2,
max_innermost_factor=64)",
- "l6, l7 = sch.split(loop=l3, factors=[v4, v5],
preserve_unit_iters=True)",
- "b8 = sch.rfactor(loop=l7, factor_axis=2)",
- 'sch.annotate(block_or_loop=b0,
ann_key="meta_schedule.random_compute_producer", ann_val=1)',
- ],
- [
- 'b0 = sch.get_block(name="C", func_name="main")',
- "l1, l2, l3 = sch.get_loops(block=b0)",
- "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2,
max_innermost_factor=64)",
- "l6, l7 = sch.split(loop=l3, factors=[v4, v5],
preserve_unit_iters=True)",
- "b8 = sch.rfactor(loop=l6, factor_axis=2)",
- 'sch.annotate(block_or_loop=b0,
ann_key="meta_schedule.random_compute_producer", ann_val=1)',
- ],
+ @T.prim_func
+ def cpu_matmul_2(
+ A: T.Buffer[(4, 512), "float32"],
+ B: T.Buffer[(512, 4), "float32"],
+ C: T.Buffer[(4, 4), "float32"],
+ ) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ C_rf = T.alloc_buffer([4, 4, 4], dtype="float32")
+ for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128):
+ with T.block("C_rf"):
+ vi2_0, i, j, vi2_1 = T.axis.remap("SSSR", [i2_0, i0, i1, i2_1])
+ T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j])
+ T.writes(C_rf[i, j, vi2_0])
+ with T.init():
+ C_rf[i, j, vi2_0] = T.float32(0)
+ C_rf[i, j, vi2_0] = (
+ C_rf[i, j, vi2_0] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 *
128 + vi2_1, j]
+ )
+ for i0, i1, i2_0 in T.grid(4, 4, 4):
+ with T.block("C"):
+ vi2_0, i, j = T.axis.remap("RSS", [i2_0, i0, i1])
+ T.reads(C_rf[i, j, vi2_0])
+ T.writes(C[i, j])
+ T.block_attr({"meta_schedule.random_compute_producer": 1})
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + C_rf[i, j, vi2_0]
+
+ decision_0 = [] # type: ignore
+ decision_1 = [
+ ("SamplePerfectTile", [4, 128]),
+ ]
+ decision_2 = [
+ ("SamplePerfectTile", [4, 128]),
]
- target = Target("llvm --num-cores=32")
- ctx = _create_context(
- create_prim_func(
- te_workload.matmul(
- n=4,
- m=4,
- k=512,
- )
- ),
- target=target,
- rule=add_rfactor(target=target),
+ mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512))
+ actual = ms.TuneContext(
+ mod=mod,
+ target=Target("llvm --num-cores=32"),
+ space_generator=ms.space_generator.PostOrderApply(),
+ sch_rules=[ms.schedule_rule.AddRFactor()],
+ task_name="test",
+ ).generate_design_space()
+ check_sketches(
+ mod,
+ sketches=actual,
+ expected_mods=[cpu_matmul_0, cpu_matmul_1, cpu_matmul_2],
+ expected_decisions=[decision_0, decision_1, decision_2],
)
- spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
- assert len(spaces) == 3
- check_trace(spaces, expected)
if __name__ == "__main__":