This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5dc4186338 [MetaSchedule] Add JSON Database Validation Scripts (#12948)
5dc4186338 is described below

commit 5dc418633839d112c5b7519111d5745d365e941e
Author: Xiyou Zhou <[email protected]>
AuthorDate: Wed Nov 9 14:42:48 2022 -0800

    [MetaSchedule] Add JSON Database Validation Scripts (#12948)
    
    * Add validation scripts.
    
    * Fix testing script.
    
    * Fix lint.
    
    * Fix lint.
    
    * Fix inputs.
    
    * Fix lint.
    
    * Fix lint.
    
    * Add timer func.
    
    * Fix ci.
    
    * Address comments.
    
    * Add total time statistics.
    
    * Fix lint.
---
 python/tvm/meta_schedule/profiler.py               |   2 +-
 .../meta_schedule/testing/custom_builder_runner.py |   4 +-
 python/tvm/meta_schedule/testing/tune_utils.py     |  55 +++-
 .../tvm/meta_schedule/testing/validate_database.py | 282 +++++++++++++++++++++
 4 files changed, 336 insertions(+), 7 deletions(-)

diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py
index 7446578a38..1776666f4e 100644
--- a/python/tvm/meta_schedule/profiler.py
+++ b/python/tvm/meta_schedule/profiler.py
@@ -34,7 +34,7 @@ class Profiler(Object):
         )
 
     def get(self) -> Dict[str, float]:
-        """Get the profiling results in minutes"""
+        """Get the profiling results in seconds"""
         return _ffi_api.ProfilerGet(self)  # type: ignore # pylint: disable=no-member
 
     def table(self) -> str:
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index 1cfd4ab833..7129546dd8 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -17,7 +17,7 @@
 """Customized builder and runner methods"""
 # pylint: disable=import-outside-toplevel
 
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, Callable
 
 if TYPE_CHECKING:
     import numpy as np  # type: ignore
@@ -143,7 +143,7 @@ def run_module_via_rpc(
     rpc_config: "RPCConfig",
     lib: Union["Module", "Executable"],
     dev_type: str,
-    args: Dict[str, "np.ndarray"],
+    args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
     continuation: Callable,
     backend: Optional[str] = "graph",
 ):
diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py
index fe0984d51c..17064c64ab 100644
--- a/python/tvm/meta_schedule/testing/tune_utils.py
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -86,7 +86,7 @@ def create_timer(backend: str) -> Callable:
 
     def f_timer(
         rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable],
-        dev: tvm.device,
+        dev: tvm.runtime.Device,
         input_data: Dict[str, NDArray],
     ) -> None:
         """Run and benchmark the given runtime module, print out the result.
@@ -95,7 +95,7 @@ def create_timer(backend: str) -> Callable:
         ----------
         rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
             The runtime module or vm executable.
-        dev : tvm.device
+        dev : tvm.runtime.Device
             The device type to run workload.
         input_data : Dict[str, np.ndarray]
             The input data as a dictionary.
@@ -152,7 +152,7 @@ def create_time_per_layer(graph: str) -> Callable:
 
     def f_time_per_layer(
         rt_mod: tvm.runtime.Module,
-        dev: tvm.device,
+        dev: tvm.runtime.Device,
         input_data: Dict[str, NDArray],
     ) -> None:
         """Run and benchmark the per-layer performance of given runtime module,
@@ -162,7 +162,7 @@ def create_time_per_layer(graph: str) -> Callable:
         ----------
         rt_mod : tvm.runtime.Module
             The runtime module.
-        dev : tvm.device
+        dev : tvm.runtime.Device
             The device type to run workload.
         input_data : Dict[str, np.ndarray]
             The input data as a dictionary.
@@ -192,3 +192,50 @@ def create_time_per_layer(graph: str) -> Callable:
             )
 
     return f_time_per_layer
+
+
+def create_calculator(backend: str) -> Callable:
+    """Create a function to fetch the computing result of running the given runtime module.
+
+    Parameters
+    ----------
+    backend : str
+        The backend to use, only tir is supported for now.
+
+    Returns
+    -------
+    func : Callable
+        The function to fetch the computing result.
+    """
+
+    def f_calculator(
+        rt_mod: tvm.runtime.Module,
+        dev: tvm.runtime.Device,  # pylint: disable=unused-argument
+        input_data: Dict[str, NDArray],
+    ) -> List[NDArray]:
+        """Fetch the result of running the given runtime module.
+
+        Parameters
+        ----------
+        rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
+            The runtime module or vm executable.
+        dev : tvm.device
+            The device type to run workload.
+        input_data : Dict[str, np.ndarray]
+            The input data as a dictionary.
+        """
+        try:
+            if backend == "tir":
+                data = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])]
+                rt_mod(*data)
+                return data
+            else:
+                raise ValueError(f"Backend {backend} not supported in f_calculator!")
+
+        except Exception as exc:  # pylint: disable=broad-except
+            print(
+                f"Run module f_calculator via RPC failed, exception: {exc}",
+            )
+            return None
+
+    return f_calculator
diff --git a/python/tvm/meta_schedule/testing/validate_database.py b/python/tvm/meta_schedule/testing/validate_database.py
new file mode 100644
index 0000000000..5e48bfb6b0
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/validate_database.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""JSON Database validation script"""
+from typing import Union, Callable, List
+from distutils.util import strtobool
+import argparse
+import logging
+import warnings
+import numpy as np  # type: ignore
+
+import tvm
+from tvm.target import Target
+from tvm.ir import IRModule
+from tvm.tir import Schedule
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
+from tvm.meta_schedule.testing.tune_utils import create_calculator, generate_input_data
+from tvm._ffi import get_global_func, register_func
+from tvm.support import describe
+
+DELIMITOR = "\n" + "-" * 30 + "\n"
+
+
+def _parse_args():
+    args = argparse.ArgumentParser()
+    args.add_argument(
+        "--work-dir",
+        type=str,
+        required=True,
+        help="The path to the work directory containing database files.",
+    )
+    args.add_argument(
+        "--target",
+        type=Target,
+        required=True,
+    )
+    args.add_argument(
+        "--baseline-target",
+        type=Target,
+        default="llvm -num-cores=1",
+        required=False,
+        help="The baseline target to compile the original module.",
+    )
+    args.add_argument(
+        "--rpc-host",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-port",
+        type=int,
+        required=True,
+    )
+    args.add_argument(
+        "--rpc-key",
+        type=str,
+        required=True,
+    )
+    args.add_argument(
+        "--number",
+        type=int,
+        default=3,
+    )
+    args.add_argument(
+        "--repeat",
+        type=int,
+        default=1,
+    )
+    args.add_argument(
+        "--min-repeat-ms",
+        type=int,
+        default=100,
+    )
+    args.add_argument(
+        "--cpu-flush",
+        type=lambda x: bool(strtobool(x)),
+        help="example: True / False",
+        required=True,
+    )
+    parsed = args.parse_args()
+    parsed.target = tvm.target.Target(parsed.target)
+    parsed.rpc_config = ms.runner.RPCConfig(
+        tracker_host=parsed.rpc_host,
+        tracker_port=parsed.rpc_port,
+        tracker_key=parsed.rpc_key,
+        session_timeout_sec=600,
+    )
+    if parsed.cpu_flush and parsed.target.kind.name != "llvm":
+        warnings.warn("cpu_flush is only supported on llvm target")
+    return parsed
+
+
+# logging
+logging.basicConfig(
+    format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+)
+logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
+
+# arg parser
+ARGS = _parse_args()
+
+
+@register_func("tvm.meta_schedule.testing.default_input_generator")
+def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
+    args_info = ms.arg_info.TensorInfo.from_prim_func(mod["main"])
+    inputs = [
+        tvm.nd.array(generate_input_data(input_shape=arg_info.shape, input_dtype=arg_info.dtype))
+        for arg_info in args_info
+    ]
+    return inputs
+
+
+@register_func("tvm.meta_schedule.testing.default_check_metric")
+def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
+    assert len(a) == len(b), "Different number of outputs from two modules"
+    for i, _ in enumerate(a):
+        if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
+            return False
+    return True
+
+
+def validate_correctness(
+    original_mod: IRModule,  # compiled for "baseline_target"
+    scheduled_mod: IRModule,  # compiled for "target"
+    *,
+    baseline_target: Target,
+    target: Target,
+    dev_type: str,
+    rpc_config: ms.runner.RPCConfig,
+    f_input_generator: Union[
+        str, Callable[[IRModule], List[tvm.nd.NDArray]]
+    ] = default_input_generator,
+    f_check_metric: Union[
+        str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
+    ] = default_check_metric,
+) -> bool:
+    """Function to validate the correctness of a scheduled module.
+
+    Parameters
+    ----------
+    original_mod : IRModule
+        The original module to be compiled.
+    scheduled_mod : IRModule
+        The scheduled module to be compiled.
+    baseline_target : Target
+        The baseline target to compile the original module.
+    target : Target
+        The target to compile the scheduled module.
+    dev_type : str
+        The device type to run the module via rpc.
+    rpc_config : RPCConfig
+        The RPCConfig to run the scheduled module.
+    f_input_generator : Union[str, Callable]
+        The function to generate the input data.
+    f_check_metric : Union[str, Callable]
+        The function to check the metric.
+
+    Returns
+    -------
+    result : bool
+        The result of the validation.
+    """
+
+    def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+        """Convert a list of TVM NDArray to a list of numpy array"""
+        assert a is not None, "Empty result cannot be converted to numpy"
+        return [x.numpy() for x in a]
+
+    def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+        """Convert a list of numpy array to a list of TVM NDArray"""
+        assert a is not None, "Empty result cannot be converted to TVM NDArray"
+        return [tvm.nd.array(x) for x in a]
+
+    def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
+        """Build and run the module on the target device."""
+        rt_mod = tvm.build(mod, target=target)
+        return run_module_via_rpc(
+            rpc_config=rpc_config,
+            lib=rt_mod,
+            dev_type=dev_type,
+            args={i: v for i, v in enumerate(inputs)},  # pylint: disable=unnecessary-comprehension
+            continuation=create_calculator(backend="tir"),
+            backend="tir",
+        )
+
+    # fetch functions & prepare inputs
+    if isinstance(f_input_generator, str):
+        f_input_generator = get_global_func(f_input_generator)
+    if isinstance(f_check_metric, str):
+        f_check_metric = get_global_func(f_check_metric)
+    inputs = to_numpy(f_input_generator(original_mod))  # type: ignore
+    # build & run original result
+    original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
+    scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
+    # check metric
+    if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)):  # type: ignore
+        return True
+    else:
+        print(
+            ("\n\n").join(
+                [
+                    "Validation failed!",
+                    "Original Result:" + DELIMITOR + str(original_res),
+                    "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+                    "Input:" + DELIMITOR + str(inputs),
+                    "Original IRModule:" + DELIMITOR + original_mod.script(),
+                    "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+                ]
+            )
+        )
+        return False
+
+
+def main():
+    """Main function"""
+    describe()
+    database = ms.database.create(work_dir=ARGS.work_dir)
+    target = ARGS.target
+    if target.kind.name == "llvm":
+        dev_type = "cpu"
+    elif target.kind.name == "cuda":
+        dev_type = "cuda"
+    else:
+        raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
+    records = database.get_all_tuning_records()
+    with ms.Profiler() as profiler:
+        for i, record in enumerate(records):
+            scope_name = f"validate #{i}"
+            with profiler.timeit(scope_name):
+                original_mod = record.workload.mod
+                sch = Schedule(original_mod)
+                record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
+                scheduled_mod = sch.mod
+                is_success = False
+                try:
+                    is_success = validate_correctness(
+                        original_mod=original_mod,
+                        scheduled_mod=scheduled_mod,
+                        target=target,
+                        baseline_target=ARGS.baseline_target,
+                        dev_type=dev_type,
+                        rpc_config=ARGS.rpc_config,
+                    )
+                except Exception as e:  # pylint: disable=broad-except, invalid-name
+                    print(
+                        ("\n\n").join(
+                            [
+                                "Validation failed!",
+                                "Original IRModule:" + DELIMITOR + original_mod.script(),
+                                "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+                                "Exception" + DELIMITOR + str(e),
+                            ]
+                        )
+                    )
+            if is_success:
+                print(
+                    f"Progress {i+1: 6d} / {len(records): 6d} checked,"
+                    f" used {float(profiler.get()[scope_name]): 3.3f} sec."
+                )
+            else:
+                return
+
+    print("Validation passed!")
+    print(f"Total time spent: {float(profiler.get()['Total']): 3.3f} sec.")
+
+
+if __name__ == "__main__":
+    main()

Reply via email to