junrushao commented on code in PR #12895:
URL: https://github.com/apache/tvm/pull/12895#discussion_r985387124


##########
python/tvm/meta_schedule/task_scheduler/task_scheduler.py:
##########
@@ -101,15 +90,43 @@ def join_running_task(self, task_id: int) -> 
List[RunnerResult]:
         """
         return _ffi_api.TaskSchedulerJoinRunningTask(self, task_id)  # type: 
ignore # pylint: disable=no-member
 
-    def initialize_task(self, task_id: int) -> None:
-        """Initialize modules of the given task.
+    def tune(
+        self,
+        tasks: List[TuneContext],
+        task_weights: List[float],
+        max_trials_global: int,
+        max_trials_per_task: int,
+        num_trials_per_iter: int,
+        builder: Builder,
+        runner: Runner,
+        measure_callbacks: List[MeasureCallback],
+        database: Optional[Database],
+        cost_model: Optional[CostModel],
+    ) -> None:
+        """Auto-tuning."""

Review Comment:
   thanks! will do!



##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -15,28 +15,82 @@
 # specific language governing permissions and limitations
 # under the License.
 """MetaSchedule-Relay integration"""
-from typing import Any, Dict, List, Optional
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
 
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
 import numpy as np  # type: ignore
 from tvm import nd
 from tvm._ffi import get_global_func
 from tvm.ir import IRModule, transform
 from tvm.runtime import NDArray
 from tvm.target import Target
 
+from .builder import Builder
+from .cost_model import CostModel
+from .database import Database
 from .extracted_task import ExtractedTask
-from .utils import autotvm_silencer
+from .logging import get_loggers_from_work_dir
+from .measure_callback import MeasureCallback
+from .profiler import Profiler
+from .runner import Runner
+from .search_strategy import SearchStrategy
+from .space_generator import SpaceGenerator
+from .task_scheduler import TaskScheduler
+from .tune import tune_tasks
+from .tune_context import TuneContext
+from .utils import fork_seed
+
+_extract_task = get_global_func(  # pylint: disable=invalid-name
+    "relay.backend.MetaScheduleExtractTask",
+    allow_missing=False,
+)
+
+
+@contextmanager
+def _autotvm_silencer():
+    """A context manager that silences autotvm warnings."""
+    from tvm import autotvm  # pylint: disable=import-outside-toplevel
+
+    silent = autotvm.GLOBAL_SCOPE.silent
+    autotvm.GLOBAL_SCOPE.silent = True
+    try:
+        yield
+    finally:
+        autotvm.GLOBAL_SCOPE.silent = silent
+
+
+def _normalize_params(
+    mod: IRModule,
+    target: Union[Target, str],
+    params: Optional[Dict[str, NDArray]],
+) -> Tuple[IRModule, Target, Dict[str, NDArray]]:
+    from tvm.relay import Function  # pylint: disable=import-outside-toplevel
+
+    if isinstance(mod, Function):
+        mod = IRModule.from_expr(mod)
+    if not isinstance(target, Target):
+        target = Target(target)
+    if params is None:
+        params = {}
+    relay_params = {}
+    for name, param in params.items():
+        if isinstance(param, np.ndarray):
+            param = nd.array(param)
+        relay_params[name] = param
+
+    return mod, target, relay_params
 
 
 def extract_task_from_relay(
     mod: IRModule,
-    target: Target,
-    params: Optional[Dict[str, NDArray]] = None,
+    target: Union[Target, str],
+    params: Optional[Dict[str, NDArray]],
     *,
-    opt_level: int = 3,
-    pass_config: Optional[Dict[str, Any]] = None,

Review Comment:
   I believe that, given the scope of this PR, it's good that we focus on the existing functionality as tested; expanding it to more advanced cases could be a separate PR.



##########
src/meta_schedule/schedule_rule/schedule_rule.cc:
##########
@@ -51,6 +51,125 @@ ScheduleRule ScheduleRule::PyScheduleRule(
   return ScheduleRule(n);
 }
 
+Array<ScheduleRule> ScheduleRule::DefaultLLVM() {
+  return {
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array<String>{"tir.exp"}),
+      ScheduleRule::AddRFactor(
+          /*max_jobs_per_core=*/16,
+          /*max_innermost_factor=*/Integer(64)),
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/16,
+          /*max_vectorize_extent=*/64,
+          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
+          /*unroll_explicit=*/true),
+      ScheduleRule::RandomComputeLocation(),
+  };
+}
+
+Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
+  return {
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSSRRSRS",
+          /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", 
"threadIdx.x"},
+          /*max_innermost_factor=*/Integer(64),
+          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
+          /*reuse_read=*/
+          Map<String, ObjectRef>{{"req", String("must")},
+                                 {"levels", Array<Integer>{4}},  //
+                                 {"scope", String("shared")}},
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("must")},
+                                 {"levels", Array<Integer>{3}},  //
+                                 {"scope", String("local")}}),
+      ScheduleRule::AutoInline(
+          /*into_producer=*/true,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/false,
+          /*require_injective=*/false,
+          /*require_ordered=*/false,
+          /*disallow_op=*/Array<String>{}),
+      ScheduleRule::CrossThreadReduction(
+          /*thread_extents=*/Array<Integer>{4, 8, 16, 32, 64, 128, 256, 512}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/-1,
+          /*max_vectorize_extent=*/-1,
+          /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512, 1024},
+          /*unroll_explicit=*/true),
+      ScheduleRule::AutoBind(
+          /*max_threadblocks=*/256,
+          /*thread_extents*/ Array<Integer>{32, 64, 128, 256, 512, 1024}),
+  };
+}
+
+Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
+  Array<Map<String, String>> intrin_groups = {
+      {
+          {"init", "wmma_fill_16x16x16_f16"},
+          {"load_a", "wmma_load_16x16x16_f16_a"},
+          {"load_b", "wmma_load_16x16x16_f16_b"},
+          {"compute", "wmma_sync_16x16x16_f16f16f16"},
+          {"store", "wmma_store_16x16x16_f16_shared"},
+      },
+      {
+          {"init", "wmma_fill_16x16x16_f16"},
+          {"load_a", "wmma_load_16x16x16_f16_a"},
+          {"load_b", "wmma_load_16x16x16_f16_b_trans"},
+          {"compute", "wmma_sync_16x16x16_f16f16f16_trans"},
+          {"store", "wmma_store_16x16x16_f16_shared"},
+      },
+      {
+          {"init", "wmma_fill_16x16x16_s32"},
+          {"load_a", "wmma_load_16x16x16_s8_a"},
+          {"load_b", "wmma_load_16x16x16_s8_b"},
+          {"compute", "wmma_sync_16x16x16_s8s8s32"},
+          {"store", "wmma_store_16x16x16_s32_shared"},
+      },
+      {
+          {"init", "wmma_fill_16x16x16_s32"},
+          {"load_a", "wmma_load_16x16x16_s8_a"},
+          {"load_b", "wmma_load_16x16x16_s8_b_trans"},
+          {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
+          {"store", "wmma_store_16x16x16_s32_shared"},
+      },
+  };
+  Array<ScheduleRule> results{ScheduleRule::MultiLevelTilingTensorCore(
+      /*intrin_groups=*/intrin_groups,
+      /*structure=*/"SSSRRSRS",
+      /*tile_binds=*/Array<String>{"blockIdx.x", "vthread.x", "threadIdx.x"},
+      /*max_innermost_factor=*/Integer(64),
+      /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
+      /*reuse_read=*/
+      Map<String, ObjectRef>{{"req", String("must")},
+                             {"levels", Array<Integer>{4}},  //
+                             {"scope", String("shared")}},
+      /*reuse_write=*/
+      Map<String, ObjectRef>{{"req", String("must")},
+                             {"levels", Array<Integer>{3}},  //
+                             {"scope", String("local")}},

Review Comment:
   Marked as resolved since no further conversation happened in this thread — feel 
free to unresolve if it's not actually fixed :-)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to