This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0c1aad78f9 [Testing] Add tvm.testing.local_run (#15268)
0c1aad78f9 is described below

commit 0c1aad78f9ef6071cc383905f11a37fc29e1a0a6
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jul 8 19:13:33 2023 -0700

    [Testing] Add tvm.testing.local_run (#15268)
    
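    This PR introduces `tvm.testing.local_run`, a convenient NumPy-in,
    NumPy-out interface for quickly running a `runtime.Module` on a local
    device: it prints the module's measured running time and returns the
    arguments, with device arrays copied back to NumPy, so outputs can be
    read directly. As part of this change, `python/tvm/testing/rpc_run.py`
    is renamed to `runner.py`, and `tvm.testing` now re-exports both
    `local_run` and `rpc_run` from it.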
    
    Example:
    
    ```python
    
    @I.ir_module
    class Module:
      ...
    
    n = 128
    np_a = np.random.uniform(-1, 1, [1, 32, 1, 128]).astype(np.float16)
    np_b = np.random.uniform(-1, 1, [1, 32, n, 128]).astype(np.float16)
    np_c = np.random.uniform(-1, 1, [1, 1, 1, n]).astype(np.float16)
    np_d = np.random.uniform(-1, 1, [1, 32, 1, n]).astype(np.float32)
    
    _, _, _, np_d = local_run(
        tvm.build(Module, target="llvm"),
        device_type="cpu",
        args=[np_a, np_b, np_c, np_d],
    )
    ```
---
 python/tvm/testing/__init__.py               |  2 +-
 python/tvm/testing/{rpc_run.py => runner.py} | 85 +++++++++++++++++++++++++---
 2 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index 08c0926277..3e5f838a27 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -43,5 +43,5 @@ from .popen_pool import (
     slow_summation,
     timeout_job,
 )
-from .rpc_run import rpc_run
+from .runner import local_run, rpc_run
 from .utils import *
diff --git a/python/tvm/testing/rpc_run.py b/python/tvm/testing/runner.py
similarity index 66%
rename from python/tvm/testing/rpc_run.py
rename to python/tvm/testing/runner.py
index 08c00ca4d1..5b677df4bd 100644
--- a/python/tvm/testing/rpc_run.py
+++ b/python/tvm/testing/runner.py
@@ -22,16 +22,14 @@ from typing_extensions import Literal
 
 if TYPE_CHECKING:
     import numpy as np
-
     from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
     from tvm.runtime import Device, Module, NDArray
 
 # pylint: disable=import-outside-toplevel,protected-access
 
 
-def _args_to_remote(args, device):
+def _args_to_device(args, device):
     import numpy as np
-
     from tvm.runtime.ndarray import NDArray, empty
 
     uploaded_args = []
@@ -45,7 +43,7 @@ def _args_to_remote(args, device):
     return uploaded_args
 
 
-def _args_to_local(args):
+def _args_to_numpy(args):
     from tvm.runtime.ndarray import NDArray
 
     downloaded_args = []
@@ -77,6 +75,77 @@ def _normalize_export_func(export_func, output_format) -> 
Tuple[Callable, str]:
     return export_func, output_format
 
 
+def local_run(  # pylint: disable=too-many-arguments,too-many-locals
+    mod: "Module",
+    device_type: str,
+    args: List[Union["np.ndarray", "NDArray", int, float]],
+    evaluator_config: Optional["EvaluatorConfig"] = None,
+    export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] 
= "tar",
+    output_format: Optional[str] = None,
+):
+    """Run a TVM module on a local device.
+
+    Parameters
+    ----------
+    mod : Module
+        The TVM module to run.
+    device_type : str
+        The device type to run the module on.
+    args : List[Union[np.ndarray, NDArray, int, float]]
+        The arguments to be fed to the module.
+    evaluator_config : Optional[EvaluatorConfig]
+        The evaluator configuration to use.
+    export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
+        The function to export the module to a file.
+        If callable, it must be a function that takes two arguments: the 
module to export and the
+        path to export to.
+        If "tar", the module will be exported to a tar file.
+        If "ndk", the module will be exported to a shared library.
+    output_format : Optional[str]
+        The format of the exported module.
+        If not specified, it will be inferred from the `export_func` argument.
+
+    Returns
+    -------
+    args : List[Union[np.ndarray, NDArray, int, float]]
+        The results of running the module.
+    """
+    import os.path as osp
+    import tempfile
+
+    from tvm.meta_schedule.runner import EvaluatorConfig
+    from tvm.runtime import device, load_module
+
+    evaluator_config = EvaluatorConfig._normalized(evaluator_config)
+    export_func, output_format = _normalize_export_func(export_func, 
output_format)
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
+        export_func(mod, artifact_path)
+        device: Device = device(device_type, 0)
+
+        try:
+            args = _args_to_device(args, device)
+            remote_mod = load_module(artifact_path)
+            profile_result = remote_mod.time_evaluator(
+                func_name=remote_mod.entry_name,
+                dev=device,
+                number=evaluator_config.number,
+                repeat=evaluator_config.repeat,
+                min_repeat_ms=evaluator_config.min_repeat_ms,
+                f_preproc="cache_flush_cpu_non_first_arg"
+                if evaluator_config.enable_cpu_cache_flush
+                else "",
+            )(*args)
+            print(profile_result)
+            remote_mod(*args)
+            args = _args_to_numpy(args)
+        finally:
+            pass
+
+    return args
+
+
 def rpc_run(  # pylint: disable=too-many-arguments,too-many-locals
     mod: "Module",
     device_type: str,
@@ -103,7 +172,7 @@ def rpc_run(  # pylint: 
disable=too-many-arguments,too-many-locals
         If not specified, the default RPC configuration will be used, which 
reads the following
         environment variables:
         - TVM_TRACKER_HOST
-        - TVM_TRACKER_PORmod
+        - TVM_TRACKER_PORT
         - TVM_TRACKER_KEY
     export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
         The function to export the module to a file.
@@ -134,12 +203,12 @@ def rpc_run(  # pylint: 
disable=too-many-arguments,too-many-locals
         artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
         _, remote_path = osp.split(artifact_path)
         session = rpc_config.connect_server()
-        device: Device = session.device(dev_type=device_type, dev_id=0)
+        device: Device = session.device(device_type, 0)
 
         export_func(mod, artifact_path)
         try:
             session.upload(artifact_path, remote_path)
-            args = _args_to_remote(args, device)
+            args = _args_to_device(args, device)
             remote_mod = session.load_module(remote_path)
             profile_result = remote_mod.time_evaluator(
                 func_name=remote_mod.entry_name,
@@ -153,7 +222,7 @@ def rpc_run(  # pylint: 
disable=too-many-arguments,too-many-locals
             )(*args)
             print(profile_result)
             remote_mod(*args)
-            args = _args_to_local(args)
+            args = _args_to_numpy(args)
         finally:
             session.remove(remote_path)
             session.remove(remote_path + "." + output_format)

Reply via email to