This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new eeae66b301 [Unity][MetaSchedule] Skip Scheduled PrimFuncs in Task
Generation (#14402)
eeae66b301 is described below
commit eeae66b301361193997ee37a44b1f364a1cfb861
Author: Xiyou Zhou <[email protected]>
AuthorDate: Tue Apr 11 07:51:18 2023 -0700
[Unity][MetaSchedule] Skip Scheduled PrimFuncs in Task Generation (#14402)
This PR introduces a check to skip tasks that have been scheduled either by the
DefaultGPUSchedule pass or the MetaScheduleApplyDatabase pass.
---
python/tvm/meta_schedule/relax_integration.py | 4 ++++
src/tir/transforms/default_gpu_schedule.cc | 23 +++++++++++++++++++++-
.../relax/test_transform_meta_schedule_tuning.py | 4 ++--
.../test_transform_default_gpu_schedule.py | 14 ++++++-------
4 files changed, 35 insertions(+), 10 deletions(-)
diff --git a/python/tvm/meta_schedule/relax_integration.py
b/python/tvm/meta_schedule/relax_integration.py
index db22214b76..62f5865242 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -16,6 +16,7 @@
# under the License.
"""Meta schedule integration with high-level IR"""
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+import warnings
# isort: off
from typing_extensions import Literal
@@ -124,6 +125,9 @@ def extracted_tasks_to_tune_contexts(
get_loggers_from_work_dir(work_dir, [t.task_name for t in
extracted_tasks]),
fork_seed(seed, n=len(extracted_tasks)),
):
+ if task.mod.attrs is not None and
task.mod.attrs.get("tir.is_scheduled", False):
+            warnings.warn(f"The task {task.task_name} is already scheduled,
skipping it.")
+ continue
tasks.append(
TuneContext(
mod=task.dispatched[0],
diff --git a/src/tir/transforms/default_gpu_schedule.cc
b/src/tir/transforms/default_gpu_schedule.cc
index 2b56dda0d6..78197a9078 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -74,6 +74,27 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV&
block, int64_t max_thread
}
}
+IRModule MarkScheduled(const IRModule& mod) {
+ Map<GlobalVar, BaseFunc> result;
+
+ for (const auto& [gv, base_func] : mod->functions) {
+ if (const auto* prim_func_node = base_func.as<tir::PrimFuncNode>()) {
+ tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
+ tir::PrimFunc new_prim_func =
+ WithAttr(std::move(prim_func), tir::attr::kIsScheduled, Bool(true));
+ result.Set(gv, new_prim_func);
+ } else {
+ result.Set(gv, base_func);
+ }
+ }
+
+ return IRModule(result, // functions
+ mod->type_definitions, // type_definitions
+ mod->import_set_, // import_set
+ mod->source_map, // map
+                  mod->attrs);            // attrs
+}
+
Pass DefaultGPUSchedule() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
@@ -96,7 +117,7 @@ Pass DefaultGPUSchedule() {
}
}
}
- return sch->mod();
+ return MarkScheduled(sch->mod());
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py
b/tests/python/relax/test_transform_meta_schedule_tuning.py
index 13c81ba962..39331548e4 100644
--- a/tests/python/relax/test_transform_meta_schedule_tuning.py
+++ b/tests/python/relax/test_transform_meta_schedule_tuning.py
@@ -120,7 +120,7 @@ class DefaultScheduledModule:
B: T.Buffer((32, 32), "float32"),
C: T.Buffer((32, 32), "float32"),
):
- T.func_attr({"global_symbol": "tir_matmul"})
+ T.func_attr({"global_symbol": "tir_matmul", "tir.is_scheduled": True})
# with T.block("root"):
for i0_j0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_j0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
@@ -137,7 +137,7 @@ class DefaultScheduledModule:
@T.prim_func
def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32),
"float32")):
- T.func_attr({"global_symbol": "tir_relu"})
+ T.func_attr({"global_symbol": "tir_relu", "tir.is_scheduled": True})
# with T.block("root"):
for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py
b/tests/python/unittest/test_transform_default_gpu_schedule.py
index 2503a7009d..85d78d6bba 100644
--- a/tests/python/unittest/test_transform_default_gpu_schedule.py
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -50,7 +50,7 @@ def test_broadcast_to_symbolic():
rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
var_T_broadcast_to: T.handle,
):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
x_0 = T.int64()
x_1 = T.int64()
T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
@@ -128,7 +128,7 @@ def test_matmul():
B: T.Buffer((32, 32), "float16"),
C: T.Buffer((32, 32), "float16"),
):
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ T.func_attr({"tir.is_scheduled": True, "global_symbol": "main",
"tir.noalias": True})
# with T.block("root"):
for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i_j_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
@@ -179,7 +179,7 @@ def test_add():
),
T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)),
"float32"),
):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
for i0_i1_i2_i3_fused_1 in T.thread_binding(
@@ -248,7 +248,7 @@ def test_full():
rxplaceholder: T.Buffer((), "int32"),
T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
for i0_i1_fused_1 in T.thread_binding(T.int64(6),
thread="threadIdx.x"):
@@ -284,7 +284,7 @@ def test_scheduled():
rxplaceholder: T.Buffer((), "int32"),
T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
for i0_i1_fused_1 in T.thread_binding(T.int64(6),
thread="threadIdx.x"):
@@ -345,7 +345,7 @@ def test_multiple():
),
T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)),
"float32"),
):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
for i0_i1_i2_i3_fused_1 in T.thread_binding(
@@ -389,7 +389,7 @@ def test_multiple():
rxplaceholder: T.Buffer((), "int32"),
T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
for i0_i1_fused_0 in T.thread_binding(T.int64(1),
thread="blockIdx.x"):
for i0_i1_fused_1 in T.thread_binding(T.int64(6),
thread="threadIdx.x"):