This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 558ba99c7c [MetaSchedule] Tuning Script Upgrade (#11797)
558ba99c7c is described below

commit 558ba99c7cad6fa5f01cfdb2bd6bdd2cec6087db
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Jun 29 21:11:41 2022 -0700

    [MetaSchedule] Tuning Script Upgrade (#11797)
    
    * Support uint8.
    
    * Modify tuning functions.
    
    * Follow legacy setting, use int32 for uint8.
    
    * Add vm support.
    
    * Fix vm usage.
    
    * Use vm in rpc run module.
    
    * Fix lint & stuff.
    
    * Fix backend.
    
    * Fix ftimer.
    
    * Fix lint.
    
    * Limit backend choice.
    
    * Add try catch.
    
    * Display name in rpc try catch.
    
    * Support ApplyHistoryBest (ahb) from tune_relay.
    
    * Modify scripts.
    
    * Fix typo.
    
    * Minor fix.
    
    * Fix try catch & func name.
    
    * Fix utils.
    
    * Move utils to tune_utils.
    
    * Fix tune_utils.
---
 python/tvm/auto_scheduler/testing/tune_onnx.py     | 150 +++++++---------
 python/tvm/auto_scheduler/testing/tune_relay.py    | 145 ++++++---------
 python/tvm/auto_scheduler/testing/tune_te.py       |  97 ++++++-----
 python/tvm/meta_schedule/cost_model/cost_model.py  |   2 +-
 .../meta_schedule/testing/custom_builder_runner.py |  14 +-
 python/tvm/meta_schedule/testing/tune_onnx.py      |  86 +++------
 python/tvm/meta_schedule/testing/tune_relay.py     |  84 +++------
 python/tvm/meta_schedule/testing/tune_te.py        |  16 +-
 python/tvm/meta_schedule/testing/tune_utils.py     | 194 +++++++++++++++++++++
 python/tvm/meta_schedule/testing/utils.py          |   3 +-
 python/tvm/meta_schedule/tune.py                   |  20 ++-
 11 files changed, 448 insertions(+), 363 deletions(-)

diff --git a/python/tvm/auto_scheduler/testing/tune_onnx.py 
b/python/tvm/auto_scheduler/testing/tune_onnx.py
index 5444794cf1..a3299c05bb 100644
--- a/python/tvm/auto_scheduler/testing/tune_onnx.py
+++ b/python/tvm/auto_scheduler/testing/tune_onnx.py
@@ -15,18 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import json
 import os
-
-from distutils.util import strtobool
-import numpy as np  # type: ignore
 import onnx  # type: ignore
+
 import tvm
 from tvm import auto_scheduler
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.tune_utils import generate_input_data, 
create_timer
 from tvm.meta_schedule.utils import cpu_count
 from tvm.relay.frontend import from_onnx
 from tvm.support import describe
@@ -96,17 +96,23 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
+    )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        help="example: graph / vm",
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -135,6 +141,7 @@ def main():
         repeat=ARGS.repeat,
         min_repeat_ms=ARGS.min_repeat_ms,
         enable_cpu_cache_flush=ARGS.cpu_flush,
+        timeout=ARGS.rpc_config.session_timeout_sec,
     )
 
     if ARGS.target.kind.name == "llvm":
@@ -163,102 +170,63 @@ def main():
     onnx_model = onnx.load(ARGS.onnx_path)
     shape_dict = {}
     for item in ARGS.input_shape:
-        print(f"  input_name: {item['name']}")
+        print(f"  input_name : {item['name']}")
         print(f"  input_shape: {item['shape']}")
         print(f"  input_dtype: {item['dtype']}")
         shape_dict[item["name"]] = item["shape"]
     mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True)
-    tasks, task_weights = auto_scheduler.extract_tasks(
-        mod["main"],
-        params,
-        target=ARGS.target,
-        hardware_params=hardware_params,
-    )
-    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
-        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: 
{task.workload_key}) =====")
-        print(task.compute_dag)
-
-    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
-    tuner.tune(
-        auto_scheduler.TuningOptions(
-            num_measure_trials=ARGS.num_trials,
-            runner=runner,
-            measure_callbacks=[
-                auto_scheduler.RecordToFile(log_file),
-            ],
-        ),
-        adaptive_training=ARGS.adaptive_training,
-    )
-
-    with auto_scheduler.ApplyHistoryBest(log_file):
-        with tvm.transform.PassContext(
-            opt_level=3,
-            config={"relay.backend.use_auto_scheduler": True},
-        ):
-            lib = relay.build(
-                mod,
-                target=ARGS.target,
-                params=params,
+    input_data = {
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for 
item in ARGS.input_shape
+    }
+
+    with ms.Profiler() as profiler:
+        tasks, task_weights = auto_scheduler.extract_tasks(
+            mod["main"],
+            params,
+            target=ARGS.target,
+            hardware_params=hardware_params,
+        )
+        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+            print(
+                f"==== Task {idx}: {task.desc} "
+                f"(weight {task_weight} key: {task.workload_key}) ====="
             )
-    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    input_data = {}
-    for item in ARGS.input_shape:
-        input_name, input_shape, input_dtype = item["name"], item["shape"], 
item["dtype"]
-        if input_dtype.startswith("float"):
-            input_data[input_name] = 
np.random.uniform(size=input_shape).astype(input_dtype)
-        else:
-            input_data[input_name] = np.random.randint(
-                low=0, high=10000, size=input_shape, dtype=input_dtype
+            print(task.compute_dag)
+
+        if ARGS.num_trials > 0:
+            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+            tuner.tune(
+                auto_scheduler.TuningOptions(
+                    num_measure_trials=ARGS.num_trials,
+                    runner=runner,
+                    measure_callbacks=[
+                        auto_scheduler.RecordToFile(log_file),
+                    ],
+                ),
+                adaptive_training=ARGS.adaptive_training,
             )
 
-    def f_timer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.graph_executor import GraphModule
-
-        # pylint: enable=import-outside-toplevel
-
-        mod = GraphModule(rt_mod["default"](dev))
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        ftimer = mod.module.time_evaluator(
-            "run",
-            dev,
-            min_repeat_ms=500,
-            repeat=3,
-        )
-        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-        print("Running time in time_evaluator: ", results)
+        relay_build = {"graph": relay.build, "vm": 
relay.vm.compile}[ARGS.backend]
+        with auto_scheduler.ApplyHistoryBest(log_file):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_auto_scheduler": True},
+            ):
+                lib = relay_build(
+                    mod,
+                    target=ARGS.target,
+                    params=params,
+                )
+    print("Tuning Time:")
+    print(profiler.table())
 
     run_module_via_rpc(
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
         args=input_data,
-        continuation=f_timer,
-    )
-
-    def f_per_layer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.debugger.debug_executor import create
-
-        # pylint: enable=import-outside-toplevel
-        mod = create(graph, rt_mod, dev)
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
-        graph_time = mod.run_individual(number=10, repeat=1, 
min_repeat_ms=5000)
-        print("|graph_nodes| = ", len(graph_nodes))
-        print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, 
graph_time)}
-        for k, v in graph_nodes_time.items():
-            print(f"{k} : {v:.3f}")
-
-    run_module_via_rpc(
-        rpc_config=ARGS.rpc_config,
-        lib=rt_mod,
-        dev_type=ARGS.target.kind.name,
-        args=input_data,
-        continuation=f_per_layer,
+        continuation=create_timer(ARGS.backend),
+        backend=ARGS.backend,
     )
 
 
diff --git a/python/tvm/auto_scheduler/testing/tune_relay.py 
b/python/tvm/auto_scheduler/testing/tune_relay.py
index fedb27281a..fe747af797 100644
--- a/python/tvm/auto_scheduler/testing/tune_relay.py
+++ b/python/tvm/auto_scheduler/testing/tune_relay.py
@@ -15,18 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import json
 import os
 
-from distutils.util import strtobool
-import numpy as np  # type: ignore
 import tvm
 from tvm import auto_scheduler
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.testing.tune_utils import generate_input_data, 
create_timer
 from tvm.meta_schedule.utils import cpu_count
 from tvm.support import describe
 
@@ -94,17 +94,23 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
+    )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        help="example: graph / vm",
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -133,6 +139,7 @@ def main():
         repeat=ARGS.repeat,
         min_repeat_ms=ARGS.min_repeat_ms,
         enable_cpu_cache_flush=ARGS.cpu_flush,
+        timeout=ARGS.rpc_config.session_timeout_sec,
     )
 
     if ARGS.target.kind.name == "llvm":
@@ -164,100 +171,62 @@ def main():
         cache_dir=ARGS.cache_dir,
     )
     input_info = {input_name: input_shape}
-    input_data = {}
+    input_data = {
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for 
item in ARGS.input_shape
+    }
     for input_name, input_shape in input_info.items():
-        print(f"  input_name: {input_name}")
+        print(f"  input_name : {input_name}")
         print(f"  input_shape: {input_shape}")
         print(f"  input_dtype: {input_dtype}")
-    tasks, task_weights = auto_scheduler.extract_tasks(
-        mod["main"],
-        params,
-        target=ARGS.target,
-        hardware_params=hardware_params,
-    )
-    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
-        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: 
{task.workload_key}) =====")
-        print(task.compute_dag)
 
-    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
-    tuner.tune(
-        auto_scheduler.TuningOptions(
-            num_measure_trials=ARGS.num_trials,
-            runner=runner,
-            measure_callbacks=[
-                auto_scheduler.RecordToFile(log_file),
-            ],
-        ),
-        adaptive_training=ARGS.adaptive_training,
-    )
-
-    with auto_scheduler.ApplyHistoryBest(log_file):
-        with tvm.transform.PassContext(
-            opt_level=3,
-            config={"relay.backend.use_auto_scheduler": True},
-        ):
-            lib = relay.build(
-                mod,
-                target=ARGS.target,
-                params=params,
+    with ms.Profiler() as profiler:
+        tasks, task_weights = auto_scheduler.extract_tasks(
+            mod["main"],
+            params,
+            target=ARGS.target,
+            hardware_params=hardware_params,
+        )
+        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
+            print(
+                f"==== Task {idx}: {task.desc} "
+                f"(weight {task_weight} key: {task.workload_key}) ====="
             )
-    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    for input_name, input_shape in input_info.items():
-        if input_dtype.startswith("float"):
-            input_data[input_name] = 
np.random.uniform(size=input_shape).astype(input_dtype)
-        else:
-            input_data[input_name] = np.random.randint(
-                low=0, high=10000, size=input_shape, dtype=input_dtype
+            print(task.compute_dag)
+
+        if ARGS.num_trials > 0:
+            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
+            tuner.tune(
+                auto_scheduler.TuningOptions(
+                    num_measure_trials=ARGS.num_trials,
+                    runner=runner,
+                    measure_callbacks=[
+                        auto_scheduler.RecordToFile(log_file),
+                    ],
+                ),
+                adaptive_training=ARGS.adaptive_training,
             )
 
-    def f_timer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.graph_executor import GraphModule
-
-        # pylint: enable=import-outside-toplevel
-
-        mod = GraphModule(rt_mod["default"](dev))
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        ftimer = mod.module.time_evaluator(
-            "run",
-            dev,
-            min_repeat_ms=500,
-            repeat=3,
-        )
-        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-        print("Running time in time_evaluator: ", results)
+        relay_build = {"graph": relay.build, "vm": 
relay.vm.compile}[ARGS.backend]
+        with auto_scheduler.ApplyHistoryBest(log_file):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_auto_scheduler": True},
+            ):
+                lib = relay_build(
+                    mod,
+                    target=ARGS.target,
+                    params=params,
+                )
+    print("Tuning Time:")
+    print(profiler.table())
 
     run_module_via_rpc(
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
         args=input_data,
-        continuation=f_timer,
-    )
-
-    def f_per_layer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.debugger.debug_executor import create
-
-        # pylint: enable=import-outside-toplevel
-        mod = create(graph, rt_mod, dev)
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
-        graph_time = mod.run_individual(number=10, repeat=1, 
min_repeat_ms=5000)
-        print("|graph_nodes| = ", len(graph_nodes))
-        print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = {k: float(np.mean(v)) for k, v in zip(graph_nodes, 
graph_time)}
-        for k, v in graph_nodes_time.items():
-            print(f"{k} : {v:.3f}")
-
-    run_module_via_rpc(
-        rpc_config=ARGS.rpc_config,
-        lib=rt_mod,
-        dev_type=ARGS.target.kind.name,
-        args=input_data,
-        continuation=f_per_layer,
+        continuation=create_timer(ARGS.backend),
+        backend=ARGS.backend,
     )
 
 
diff --git a/python/tvm/auto_scheduler/testing/tune_te.py 
b/python/tvm/auto_scheduler/testing/tune_te.py
index c6a5ab27cf..da3584512d 100644
--- a/python/tvm/auto_scheduler/testing/tune_te.py
+++ b/python/tvm/auto_scheduler/testing/tune_te.py
@@ -15,12 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import os
-from distutils.util import strtobool
 
 import tvm
 from tvm import auto_scheduler
+from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.te_workload import CONFIGS
 from tvm.meta_schedule.utils import cpu_count
 from tvm.support import describe
@@ -79,20 +80,26 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
+        required=False,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
+    parsed.rpc_config = ms.runner.RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=60,
+    )
     return parsed
 
 
@@ -100,12 +107,19 @@ ARGS = _parse_args()
 
 
 def main():
-    describe()
-    print(f"Workload: {ARGS.workload}")
     log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")
-    workload_func, params = CONFIGS[ARGS.workload]
-    params = params[0]  # type: ignore
-    workload_func = auto_scheduler.register_workload(workload_func)
+
+    runner = auto_scheduler.RPCRunner(
+        key=ARGS.rpc_key,
+        host=ARGS.rpc_host,
+        port=ARGS.rpc_port,
+        n_parallel=cpu_count(logical=True),
+        number=ARGS.number,
+        repeat=ARGS.repeat,
+        min_repeat_ms=ARGS.min_repeat_ms,
+        enable_cpu_cache_flush=ARGS.cpu_flush,
+        timeout=ARGS.rpc_config.session_timeout_sec,
+    )
 
     if ARGS.target.kind.name == "llvm":
         hardware_params = auto_scheduler.HardwareParams(
@@ -127,37 +141,42 @@ def main():
         )
     else:
         raise NotImplementedError(f"Unsupported target {ARGS.target}")
-    task = auto_scheduler.SearchTask(
-        func=workload_func,
-        args=params,
-        target=ARGS.target,
-        hardware_params=hardware_params,
-    )
-    runner = auto_scheduler.RPCRunner(
-        key=ARGS.rpc_key,
-        host=ARGS.rpc_host,
-        port=ARGS.rpc_port,
-        n_parallel=cpu_count(logical=True),
-        number=ARGS.number,
-        repeat=ARGS.repeat,
-        min_repeat_ms=ARGS.min_repeat_ms,
-        enable_cpu_cache_flush=ARGS.cpu_flush,
-        # todo(zxybazh): set session timeout to 60 same as MS
-    )
 
-    # Inspect the computational graph
-    print("Computational DAG:")
-    print(task.compute_dag)
-    tune_option = auto_scheduler.TuningOptions(
-        num_measure_trials=ARGS.num_trials,
-        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
-        verbose=2,
-        runner=runner,
-    )
-    print("Running AutoTuning:")
-    task.tune(tune_option, adaptive_training=ARGS.adaptive_training)
+    describe()
+    print(f"Workload: {ARGS.workload}")
+    with ms.Profiler() as profiler:
+        # Same as MetaSchedule Tune TE
+        # Does not count ApplyHistoryBest time
+
+        workload_func, params = CONFIGS[ARGS.workload]
+        params = params[0]  # type: ignore
+        workload_func = auto_scheduler.register_workload(workload_func)
+
+        task = auto_scheduler.SearchTask(
+            func=workload_func,
+            args=params,
+            target=ARGS.target,
+            hardware_params=hardware_params,
+        )
+        # Inspect the computational graph
+        print("Computational DAG:")
+        print(task.compute_dag)
+        tune_option = auto_scheduler.TuningOptions(
+            num_measure_trials=ARGS.num_trials,
+            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
+            verbose=2,
+            runner=runner,
+        )
+        if ARGS.num_trials > 0:
+            print("Running AutoTuning:")
+            task.tune(tune_option, adaptive_training=ARGS.adaptive_training)
+
+    print("Tuning Time:")
+    print(profiler.table())
+
     print("History Best:")
     print(task.print_best(log_file))
+
     sch, args = task.apply_best(log_file)
     print("Lowered TIR:")
     print(tvm.lower(sch, args, simple_mode=True))
diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py 
b/python/tvm/meta_schedule/cost_model/cost_model.py
index e479cb7254..2fdb9b9349 100644
--- a/python/tvm/meta_schedule/cost_model/cost_model.py
+++ b/python/tvm/meta_schedule/cost_model/cost_model.py
@@ -73,7 +73,7 @@ class CostModel(Object):
         _ffi_api.CostModelUpdate(self, context, candidates, results)  # type: 
ignore # pylint: disable=no-member
 
     def predict(self, context: TuneContext, candidates: 
List[MeasureCandidate]) -> np.ndarray:
-        """Update the cost model given running results.
+        """Predict normalized score with the cost model.
 
         Parameters
         ----------
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py 
b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index 3ba007d9a4..e203848c2c 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -17,7 +17,7 @@
 """Customized builder and runner methods"""
 # pylint: disable=import-outside-toplevel
 
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 if TYPE_CHECKING:
     import numpy as np  # type: ignore
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
     from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
     from tvm.runtime import Device, Module, NDArray
     from tvm.target import Target
+    from tvm.runtime.vm import Executable
 
 
 def build_relay(
@@ -143,10 +144,11 @@ def run_with_graph_executor(
 
 def run_module_via_rpc(
     rpc_config: "RPCConfig",
-    lib: "Module",
+    lib: Union["Module", "Executable"],
     dev_type: str,
     args: Dict[str, "np.ndarray"],
     continuation: Callable,
+    backend: Optional[str] = "graph",
 ):
     """Execute a tvm.runtime.Module on RPC remote"""
     # pylint: disable=import-outside-toplevel
@@ -160,13 +162,15 @@ def run_module_via_rpc(
 
     with tempfile.TemporaryDirectory() as tmp_dir:
         filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
+        if backend == "vm":
+            code, lib = lib.save()
         lib.export_library(filename, tar)
         session = rpc_config.connect_server()
         session.upload(filename)
         _, filename = os.path.split(filename)
         rt_mod = session.load_module(filename)
+        if backend == "vm":
+            rt_mod = session.get_function("runtime.Load_Executable")(code, 
rt_mod)
         dev = session.device(dev_type=dev_type, dev_id=0)
-        nd_args = {}
-        for arg_key, arg_value in args.items():
-            nd_args[arg_key] = ndarray.array(arg_value, dev)
+        nd_args = {k: ndarray.array(v, dev) for k, v in args.items()}
         return continuation(rt_mod, dev, nd_args)
diff --git a/python/tvm/meta_schedule/testing/tune_onnx.py 
b/python/tvm/meta_schedule/testing/tune_onnx.py
index 8ae9ab1ed0..6d473ed323 100644
--- a/python/tvm/meta_schedule/testing/tune_onnx.py
+++ b/python/tvm/meta_schedule/testing/tune_onnx.py
@@ -15,18 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import json
 import logging
-
-from distutils.util import strtobool
-import numpy as np  # type: ignore
 import onnx  # type: ignore
+
 import tvm
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.relay.frontend import from_onnx
 from tvm.support import describe
+from .tune_utils import generate_input_data, create_timer
 
 
 def _parse_args():
@@ -93,17 +93,23 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
+    )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        help="example: graph / vm",
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -127,14 +133,19 @@ ARGS = _parse_args()
 def main():
     describe()
     print(f"Workload: {ARGS.model_name}")
+
     onnx_model = onnx.load(ARGS.onnx_path)
     shape_dict = {}
     for item in ARGS.input_shape:
-        print(f"  input_name: {item['name']}")
+        print(f"  input_name : {item['name']}")
         print(f"  input_shape: {item['shape']}")
         print(f"  input_dtype: {item['dtype']}")
         shape_dict[item["name"]] = item["shape"]
     mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True)
+    input_data = {
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for 
item in ARGS.input_shape
+    }
+
     runner = ms.runner.RPCRunner(
         rpc_config=ARGS.rpc_config,
         evaluator_config=ms.runner.EvaluatorConfig(
@@ -145,6 +156,7 @@ def main():
         ),
         alloc_repeat=1,
     )
+
     with ms.Profiler() as profiler:
         lib = ms.tune_relay(
             mod=mod,
@@ -159,68 +171,18 @@ def main():
             runner=runner,  # type: ignore
             work_dir=ARGS.work_dir,
             params=params,
+            backend=ARGS.backend,
         )
     print("Tuning Time:")
     print(profiler.table())
-    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    input_data = {}
-    for item in ARGS.input_shape:
-        input_name, input_shape, input_dtype = item["name"], item["shape"], 
item["dtype"]
-        if input_dtype.startswith("float"):
-            input_data[input_name] = 
np.random.uniform(size=input_shape).astype(input_dtype)
-        else:
-            input_data[input_name] = np.random.randint(
-                low=0, high=10000, size=input_shape, dtype=input_dtype
-            )
-
-    def f_timer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.graph_executor import GraphModule
-
-        # pylint: enable=import-outside-toplevel
-
-        mod = GraphModule(rt_mod["default"](dev))
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        ftimer = mod.module.time_evaluator(
-            "run",
-            dev,
-            min_repeat_ms=500,
-            repeat=3,
-        )
-        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-        print("Running time in time_evaluator: ", results)
 
     run_module_via_rpc(
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
         args=input_data,
-        continuation=f_timer,
-    )
-
-    def f_per_layer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.debugger.debug_executor import create
-
-        # pylint: enable=import-outside-toplevel
-        mod = create(graph, rt_mod, dev)
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
-        graph_time = mod.run_individual(number=10, repeat=1, 
min_repeat_ms=5000)
-        print("|graph_nodes| = ", len(graph_nodes))
-        print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, 
graph_time)}
-        for k, v in graph_nodes_time.items():
-            print(f"{k} : {v:.3f}")
-
-    run_module_via_rpc(
-        rpc_config=ARGS.rpc_config,
-        lib=rt_mod,
-        dev_type=ARGS.target.kind.name,
-        args=input_data,
-        continuation=f_per_layer,
+        continuation=create_timer(ARGS.backend),
+        backend=ARGS.backend,
     )
 
 
diff --git a/python/tvm/meta_schedule/testing/tune_relay.py 
b/python/tvm/meta_schedule/testing/tune_relay.py
index daef48daa2..8010e36fd6 100644
--- a/python/tvm/meta_schedule/testing/tune_relay.py
+++ b/python/tvm/meta_schedule/testing/tune_relay.py
@@ -15,16 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import json
 import logging
 
-from distutils.util import strtobool
-import numpy as np  # type: ignore
 import tvm
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.meta_schedule.testing.relay_workload import get_network
+from tvm.meta_schedule.testing.tune_utils import generate_input_data, 
create_timer
 from tvm.support import describe
 
 
@@ -91,17 +91,23 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
+    )
+    args.add_argument(
+        "--backend",
+        type=str,
+        choices=["graph", "vm"],
+        help="example: graph / vm",
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -125,17 +131,21 @@ ARGS = _parse_args()
 def main():
     describe()
     print(f"Workload: {ARGS.workload}")
+
     mod, params, (input_name, input_shape, input_dtype) = get_network(
         ARGS.workload,
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
     input_info = {input_name: input_shape}
-    input_data = {}
+    input_data = {
+        item["name"]: generate_input_data(item["shape"], item["dtype"]) for 
item in ARGS.input_shape
+    }
     for input_name, input_shape in input_info.items():
-        print(f"  input_name: {input_name}")
+        print(f"  input_name : {input_name}")
         print(f"  input_shape: {input_shape}")
         print(f"  input_dtype: {input_dtype}")
+
     runner = ms.runner.RPCRunner(
         rpc_config=ARGS.rpc_config,
         evaluator_config=ms.runner.EvaluatorConfig(
@@ -146,6 +156,7 @@ def main():
         ),
         alloc_repeat=1,
     )
+
     with ms.Profiler() as profiler:
         lib = ms.tune_relay(
             mod=mod,
@@ -160,66 +171,19 @@ def main():
             runner=runner,  # type: ignore
             work_dir=ARGS.work_dir,
             params=params,
+            backend=ARGS.backend,
         )
+
     print("Tuning Time:")
     print(profiler.table())
-    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    for input_name, input_shape in input_info.items():
-        if input_dtype.startswith("float"):
-            input_data[input_name] = 
np.random.uniform(size=input_shape).astype(input_dtype)
-        else:
-            input_data[input_name] = np.random.randint(
-                low=0, high=10000, size=input_shape, dtype=input_dtype
-            )
-
-    def f_timer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.graph_executor import GraphModule
-
-        # pylint: enable=import-outside-toplevel
-
-        mod = GraphModule(rt_mod["default"](dev))
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        ftimer = mod.module.time_evaluator(
-            "run",
-            dev,
-            min_repeat_ms=500,
-            repeat=3,
-        )
-        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-        print("Running time in time_evaluator: ", results)
 
     run_module_via_rpc(
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
         args=input_data,
-        continuation=f_timer,
-    )
-
-    def f_per_layer(rt_mod, dev, input_data):
-        # pylint: disable=import-outside-toplevel
-        from tvm.contrib.debugger.debug_executor import create
-
-        # pylint: enable=import-outside-toplevel
-        mod = create(graph, rt_mod, dev)
-        for input_name, input_value in input_data.items():
-            mod.set_input(input_name, input_value)
-        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
-        graph_time = mod.run_individual(number=10, repeat=1, 
min_repeat_ms=5000)
-        print("|graph_nodes| = ", len(graph_nodes))
-        print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = {k: float(np.mean(v)) for k, v in zip(graph_nodes, 
graph_time)}
-        for k, v in graph_nodes_time.items():
-            print(f"{k} : {v:.3f}")
-
-    run_module_via_rpc(
-        rpc_config=ARGS.rpc_config,
-        lib=rt_mod,
-        dev_type=ARGS.target.kind.name,
-        args=input_data,
-        continuation=f_per_layer,
+        continuation=create_timer(ARGS.backend),
+        backend=ARGS.backend,
     )
 
 
diff --git a/python/tvm/meta_schedule/testing/tune_te.py 
b/python/tvm/meta_schedule/testing/tune_te.py
index e579c561ad..d54d92048e 100644
--- a/python/tvm/meta_schedule/testing/tune_te.py
+++ b/python/tvm/meta_schedule/testing/tune_te.py
@@ -15,14 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=missing-docstring
+from distutils.util import strtobool
 import argparse
 import logging
-from distutils.util import strtobool
 from typing import Optional
 
 import tvm
-from tvm import meta_schedule as ms
 from tvm import tir
+from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.te_workload import create_te_workload
 from tvm.support import describe
 
@@ -80,17 +80,17 @@ def _parse_args():
         default=100,
     )
     args.add_argument(
-        "--cpu-flush",
+        "--adaptive-training",
         type=lambda x: bool(strtobool(x)),
-        required=True,
+        required=False,
         help="example: True / False",
+        default=True,
     )
     args.add_argument(
-        "--adaptive-training",
+        "--cpu-flush",
         type=lambda x: bool(strtobool(x)),
-        required=False,
         help="example: True / False",
-        default=True,
+        required=True,
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -138,8 +138,10 @@ def main():
             task_name=ARGS.workload,
             work_dir=ARGS.work_dir,
         )
+
     print("Tuning Time:")
     print(profiler.table())
+
     if sch is None:
         print("No valid schedule found!")
     else:
diff --git a/python/tvm/meta_schedule/testing/tune_utils.py 
b/python/tvm/meta_schedule/testing/tune_utils.py
new file mode 100644
index 0000000000..aad8496a46
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utility functions in meta schedule"""
+from typing import Callable, Optional, Union, List, Dict
+from statistics import median
+import json
+import warnings
+import numpy as np  # type: ignore
+
+import tvm
+from tvm.runtime import NDArray
+
+
+def generate_input_data(
+    input_shape: List[int],
+    input_dtype: str,
+    *,
+    low: Optional[int] = None,
+    high: Optional[int] = None,
+) -> np.ndarray:
+    """Generate input date with given shape and data type.
+
+    Parameters
+    ----------
+    input_shape : List[int]
+        The shape of the input data.
+    input_dtype : str
+        The data type of the input data.
+
+    Returns
+    -------
+    input_data : np.ndarray
+        The generated input data with given shape and data type in numpy 
ndarray.
+    """
+    if input_dtype.startswith("float"):
+        return np.random.uniform(size=input_shape).astype(input_dtype)
+    if input_dtype in ["uint8", "int8"]:
+        return np.random.randint(
+            low=0,
+            high=127,
+            size=input_shape,
+            dtype="int32",  # TODO(zxybazh): fix the datatype when int8 / 
uint8 is supported better
+        )
+    if input_dtype in ["int32", "int64"]:
+        if low is None or high is None:
+            warnings.warn(
+                "Model input value range for shape {input_shape} of 
{input_dtype} is not set!"
+            )
+        return np.random.randint(
+            low=0 if low is None else low,
+            high=10000 if high is None else high,
+            size=input_shape,
+            dtype=input_dtype,
+        )
+    raise ValueError("Unsupported input datatype!")
+
+
+def create_timer(backend: str) -> Callable:
+    """Create a function to run and benchmark the performance of whole given 
runtime module,
+    or Executable in relay vm.
+
+    Parameters
+    ----------
+    backend : str
+        The backend to use, graph / vm.
+
+    Returns
+    -------
+    func : Callable
+        The function to benchmark the workload.
+    """
+
+    def f_timer(
+        rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable],
+        dev: tvm.device,
+        input_data: Dict[str, NDArray],
+    ) -> None:
+        """Run and benchmark the given runtime module, print out the result.
+
+        Parameters
+        ----------
+        rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
+            The runtime module or vm executable.
+        dev : tvm.device
+            The device type to run workload.
+        input_data : Dict[str, np.ndarray]
+            The input data as a dictionary.
+        """
+        from tvm.contrib.graph_executor import GraphModule  # 
pylint:disable=import-outside-toplevel
+        from tvm.runtime.vm import VirtualMachine  # 
pylint:disable=import-outside-toplevel
+
+        try:
+            if backend == "vm":
+                vm = VirtualMachine(rt_mod, dev)  # pylint: 
disable=invalid-name
+                ftimer = vm.benchmark(
+                    dev, min_repeat_ms=500, repeat=5, number=1, 
end_to_end=False, **input_data
+                )
+            elif backend == "graph":
+                mod = GraphModule(rt_mod["default"](dev))
+                for input_name, input_value in input_data.items():
+                    mod.set_input(input_name, input_value)
+                ftimer = mod.module.time_evaluator(
+                    "run", dev, min_repeat_ms=500, repeat=5, number=1
+                )()
+            else:
+                raise ValueError(f"Backend {backend} not supported in 
f_timer!")
+
+            results = list(np.array(ftimer.results) * 1000.0)  # type: ignore
+
+            print("Running time in time_evaluator: ", results)
+            print("-------------------------------")
+            print(f"    Min (ms) : {min(results)}")
+            print(f"    Max (ms) : {max(results)}")
+            print(f" Median (ms) : {median(results)}")
+            print(f"Average (ms) : {sum(results) / len(results)}")
+        except Exception as exc:  # pylint: disable=broad-except
+            print(
+                f"Run module f_timer via RPC failed, exception: {exc}",
+            )
+
+    return f_timer
+
+
+def create_time_per_layer(graph: str) -> Callable:
+    """Create a function to run and benchmark the per-layer performance of 
given runtime module,
+    given the graph output of the module from graph compiler.
+
+    Parameters
+    ----------
+    graph : str
+        The json format graph output of the module from graph compiler.
+
+    Returns
+    -------
+    func : Callable
+        The function using the json format graph.
+    """
+
+    def f_time_per_layer(
+        rt_mod: tvm.runtime.Module,
+        dev: tvm.device,
+        input_data: Dict[str, NDArray],
+    ) -> None:
+        """Run and benchmark the per-layer performance of given runtime module,
+        print out the result.
+
+        Parameters
+        ----------
+        rt_mod : tvm.runtime.Module
+            The runtime module.
+        dev : tvm.device
+            The device type to run workload.
+        input_data : Dict[str, np.ndarray]
+            The input data as a dictionary.
+        """
+        # pylint:disable=import-outside-toplevel
+        from tvm.contrib.debugger.debug_executor import create
+
+        # pylint:enable=import-outside-toplevel
+
+        try:
+            mod = create(graph, rt_mod, dev)
+            for input_name, input_value in input_data.items():
+                mod.set_input(input_name, input_value)
+            graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
+            graph_time = mod.run_individual(number=10, repeat=1, 
min_repeat_ms=5000)
+
+            print("Running time of each layer:")
+            print("---------------------------")
+            print("|graph_nodes| = ", len(graph_nodes))
+            print("|graph_time| = ", len(graph_time))
+
+            for k, v in zip(graph_nodes, graph_time):
+                print(k, float(v) * 1e6, "us")
+        except Exception as exc:  # pylint: disable=broad-except
+            print(
+                f"Run module f_time_per_layer via RPC failed, exception: 
{exc}",
+            )
+
+    return f_time_per_layer
diff --git a/python/tvm/meta_schedule/testing/utils.py 
b/python/tvm/meta_schedule/testing/utils.py
index bdd3852e40..0d011b7264 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -16,13 +16,12 @@
 # under the License.
 """Testing utility functions in meta schedule"""
 from typing import Callable, Dict, Optional, Union
-
-from tvm import meta_schedule as ms
 from tvm.ir import IRModule
 from tvm.relay import Function as RelayFunc
 from tvm.runtime import NDArray
 from tvm.target import Target
 from tvm.tir import Schedule
+from tvm import meta_schedule as ms
 
 
 def apply_fixed_schedules(
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index fabf14ab23..cd40429d16 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -24,7 +24,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, 
Optional, Union
 
 from tvm.ir import IRModule
 from tvm.ir.transform import PassContext
-from tvm.runtime import Module, NDArray
+from tvm.runtime import Module, NDArray, vm
 from tvm.target import Target
 from tvm.te import Tensor, create_prim_func
 from tvm.tir import PrimFunc, Schedule
@@ -346,8 +346,9 @@ def tune_extracted_tasks(
         cost_model=cost_model,
         measure_callbacks=measure_callbacks,
     )
-    task_scheduler.tune()
-    cost_model.save(osp.join(work_dir, "cost_model.xgb"))
+    if config.max_trials_global > 0:
+        task_scheduler.tune()
+        cost_model.save(osp.join(work_dir, "cost_model.xgb"))
     return database
 
 
@@ -516,6 +517,7 @@ def tune_relay(
     config: TuneConfig,
     work_dir: str,
     *,
+    backend: str = "graph",
     params: Optional[Dict[str, NDArray]] = None,
     builder: Optional[Builder] = None,
     runner: Optional[Runner] = None,
@@ -527,7 +529,7 @@ def tune_relay(
     postprocs: Optional[FnPostproc] = None,
     mutator_probs: Optional[FnMutatorProb] = None,
     num_threads: Optional[int] = None,
-) -> Module:
+) -> Union[Module, vm.Executable]:
     """Tune a TIR IRModule with a given target.
 
     Parameters
@@ -552,15 +554,16 @@ def tune_relay(
         The database to use.
     measure_callbacks : Optional[List[MeasureCallback]]
         The callbacks used during tuning.
+    backend : str = "graph"
+        The backend to use for relay compilation (graph / vm).
 
     Returns
     -------
-    lib : Module
-        The built runtime module for the given relay workload.
+    lib : Union[Module, tvm.runtime.vm.Executable]
+        The built runtime module or vm Executable for the given relay workload.
     """
     # pylint: disable=import-outside-toplevel
-    from tvm.relay import build as relay_build
-
+    from tvm import relay
     from .relay_integration import extract_task_from_relay
 
     # pylint: disable=protected-access, enable=import-outside-toplevel
@@ -584,6 +587,7 @@ def tune_relay(
         mutator_probs=mutator_probs,
         num_threads=num_threads,
     )
+    relay_build = {"graph": relay.build, "vm": relay.vm.compile}[backend]
     with Profiler.timeit("ApplyHistoryBest"):
         with target, autotvm_silencer(), ApplyHistoryBest(database):
             with PassContext(

Reply via email to