This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new eeae66b301 [Unity][MetaSchedule] Skip Scheduled PrimFuncs in Task 
Generation (#14402)
eeae66b301 is described below

commit eeae66b301361193997ee37a44b1f364a1cfb861
Author: Xiyou Zhou <[email protected]>
AuthorDate: Tue Apr 11 07:51:18 2023 -0700

    [Unity][MetaSchedule] Skip Scheduled PrimFuncs in Task Generation (#14402)
    
    This PR introduces a check to skip tasks that have already been scheduled, either by 
the DefaultGPUSchedule pass or the MetaScheduleApplyDatabase pass.
---
 python/tvm/meta_schedule/relax_integration.py      |  4 ++++
 src/tir/transforms/default_gpu_schedule.cc         | 23 +++++++++++++++++++++-
 .../relax/test_transform_meta_schedule_tuning.py   |  4 ++--
 .../test_transform_default_gpu_schedule.py         | 14 ++++++-------
 4 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/python/tvm/meta_schedule/relax_integration.py 
b/python/tvm/meta_schedule/relax_integration.py
index db22214b76..62f5865242 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -16,6 +16,7 @@
 # under the License.
 """Meta schedule integration with high-level IR"""
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+import warnings
 
 # isort: off
 from typing_extensions import Literal
@@ -124,6 +125,9 @@ def extracted_tasks_to_tune_contexts(
         get_loggers_from_work_dir(work_dir, [t.task_name for t in 
extracted_tasks]),
         fork_seed(seed, n=len(extracted_tasks)),
     ):
+        if task.mod.attrs is not None and 
task.mod.attrs.get("tir.is_scheduled", False):
+            warnings.warn(f"The task {task.task_name} is already scheduled, 
skipping it.")
+            continue
         tasks.append(
             TuneContext(
                 mod=task.dispatched[0],
diff --git a/src/tir/transforms/default_gpu_schedule.cc 
b/src/tir/transforms/default_gpu_schedule.cc
index 2b56dda0d6..78197a9078 100644
--- a/src/tir/transforms/default_gpu_schedule.cc
+++ b/src/tir/transforms/default_gpu_schedule.cc
@@ -74,6 +74,27 @@ void ThreadBind(tir::Schedule sch, const tir::BlockRV& 
block, int64_t max_thread
   }
 }
 
+IRModule MarkScheduled(const IRModule& mod) {
+  Map<GlobalVar, BaseFunc> result;
+
+  for (const auto& [gv, base_func] : mod->functions) {
+    if (const auto* prim_func_node = base_func.as<tir::PrimFuncNode>()) {
+      tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
+      tir::PrimFunc new_prim_func =
+          WithAttr(std::move(prim_func), tir::attr::kIsScheduled, Bool(true));
+      result.Set(gv, new_prim_func);
+    } else {
+      result.Set(gv, base_func);
+    }
+  }
+
+  return IRModule(result,                 // functions
+                  mod->type_definitions,  // type_definitions
+                  mod->import_set_,       // import_set
+                  mod->source_map,        // map
+                  mod->attrs);            // attrs
+}
+
 Pass DefaultGPUSchedule() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) {
@@ -96,7 +117,7 @@ Pass DefaultGPUSchedule() {
             }
           }
         }
-        return sch->mod();
+        return MarkScheduled(sch->mod());
       };
   return CreateModulePass(/*pass_function=*/pass_func,         //
                           /*opt_level=*/0,                     //
diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py 
b/tests/python/relax/test_transform_meta_schedule_tuning.py
index 13c81ba962..39331548e4 100644
--- a/tests/python/relax/test_transform_meta_schedule_tuning.py
+++ b/tests/python/relax/test_transform_meta_schedule_tuning.py
@@ -120,7 +120,7 @@ class DefaultScheduledModule:
         B: T.Buffer((32, 32), "float32"),
         C: T.Buffer((32, 32), "float32"),
     ):
-        T.func_attr({"global_symbol": "tir_matmul"})
+        T.func_attr({"global_symbol": "tir_matmul", "tir.is_scheduled": True})
         # with T.block("root"):
         for i0_j0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
             for i0_j0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
@@ -137,7 +137,7 @@ class DefaultScheduledModule:
 
     @T.prim_func
     def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), 
"float32")):
-        T.func_attr({"global_symbol": "tir_relu"})
+        T.func_attr({"global_symbol": "tir_relu", "tir.is_scheduled": True})
         # with T.block("root"):
         for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
             for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
diff --git a/tests/python/unittest/test_transform_default_gpu_schedule.py 
b/tests/python/unittest/test_transform_default_gpu_schedule.py
index 2503a7009d..85d78d6bba 100644
--- a/tests/python/unittest/test_transform_default_gpu_schedule.py
+++ b/tests/python/unittest/test_transform_default_gpu_schedule.py
@@ -50,7 +50,7 @@ def test_broadcast_to_symbolic():
             rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"),
             var_T_broadcast_to: T.handle,
         ):
-            T.func_attr({"tir.noalias": True})
+            T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
             x_0 = T.int64()
             x_1 = T.int64()
             T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1))
@@ -128,7 +128,7 @@ def test_matmul():
             B: T.Buffer((32, 32), "float16"),
             C: T.Buffer((32, 32), "float16"),
         ):
-            T.func_attr({"global_symbol": "main", "tir.noalias": True})
+            T.func_attr({"tir.is_scheduled": True, "global_symbol": "main", 
"tir.noalias": True})
             # with T.block("root"):
             for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
                 for i_j_fused_1 in T.thread_binding(1024, 
thread="threadIdx.x"):
@@ -179,7 +179,7 @@ def test_add():
             ),
             T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), 
"float32"),
         ):
-            T.func_attr({"tir.noalias": True})
+            T.func_attr({"tir.is_scheduled": True,  "tir.noalias": True})
             # with T.block("root"):
             for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
                 for i0_i1_i2_i3_fused_1 in T.thread_binding(
@@ -248,7 +248,7 @@ def test_full():
             rxplaceholder: T.Buffer((), "int32"),
             T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
         ):
-            T.func_attr({"tir.noalias": True})
+            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
             # with T.block("root"):
             for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
                 for i0_i1_fused_1 in T.thread_binding(T.int64(6), 
thread="threadIdx.x"):
@@ -284,7 +284,7 @@ def test_scheduled():
             rxplaceholder: T.Buffer((), "int32"),
             T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
         ):
-            T.func_attr({"tir.noalias": True})
+            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
             # with T.block("root"):
             for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
                 for i0_i1_fused_1 in T.thread_binding(T.int64(6), 
thread="threadIdx.x"):
@@ -345,7 +345,7 @@ def test_multiple():
             ),
             T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), 
"float32"),
         ):
-            T.func_attr({"tir.noalias": True})
+            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
             # with T.block("root"):
             for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
                 for i0_i1_i2_i3_fused_1 in T.thread_binding(
@@ -389,7 +389,7 @@ def test_multiple():
             rxplaceholder: T.Buffer((), "int32"),
             T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"),
         ):
-            T.func_attr({"tir.noalias": True})
+            T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
             # with T.block("root"):
             for i0_i1_fused_0 in T.thread_binding(T.int64(1), 
thread="blockIdx.x"):
                 for i0_i1_fused_1 in T.thread_binding(T.int64(6), 
thread="threadIdx.x"):

Reply via email to