This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0c1aad78f9 [Testing] Add tvm.testing.local_run (#15268)
0c1aad78f9 is described below
commit 0c1aad78f9ef6071cc383905f11a37fc29e1a0a6
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jul 8 19:13:33 2023 -0700
[Testing] Add tvm.testing.local_run (#15268)
This PR introduces `tvm.testing.local_run`, which serves as a convenient
numpy-in numpy-out interface to quickly run a `runtime.Module` in TVM
and obtain its running time and outputs.
Example:
```python
@I.ir_module
class Module:
    ...
n = 128
np_a = np.random.uniform(-1, 1, [1, 32, 1, 128]).astype(np.float16)
np_b = np.random.uniform(-1, 1, [1, 32, n, 128]).astype(np.float16)
np_c = np.random.uniform(-1, 1, [1, 1, 1, n]).astype(np.float16)
np_d = np.random.uniform(-1, 1, [1, 32, 1, n]).astype(np.float32)
_, _, _, np_d = local_run(
    tvm.build(Module, target="llvm"),
    device_type="cpu",
    args=[np_a, np_b, np_c, np_d],
)
```
---
python/tvm/testing/__init__.py | 2 +-
python/tvm/testing/{rpc_run.py => runner.py} | 85 +++++++++++++++++++++++++---
2 files changed, 78 insertions(+), 9 deletions(-)
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index 08c0926277..3e5f838a27 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -43,5 +43,5 @@ from .popen_pool import (
slow_summation,
timeout_job,
)
-from .rpc_run import rpc_run
+from .runner import local_run, rpc_run
from .utils import *
diff --git a/python/tvm/testing/rpc_run.py b/python/tvm/testing/runner.py
similarity index 66%
rename from python/tvm/testing/rpc_run.py
rename to python/tvm/testing/runner.py
index 08c00ca4d1..5b677df4bd 100644
--- a/python/tvm/testing/rpc_run.py
+++ b/python/tvm/testing/runner.py
@@ -22,16 +22,14 @@ from typing_extensions import Literal
if TYPE_CHECKING:
import numpy as np
-
from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
from tvm.runtime import Device, Module, NDArray
# pylint: disable=import-outside-toplevel,protected-access
-def _args_to_remote(args, device):
+def _args_to_device(args, device):
import numpy as np
-
from tvm.runtime.ndarray import NDArray, empty
uploaded_args = []
@@ -45,7 +43,7 @@ def _args_to_remote(args, device):
return uploaded_args
-def _args_to_local(args):
+def _args_to_numpy(args):
from tvm.runtime.ndarray import NDArray
downloaded_args = []
@@ -77,6 +75,77 @@ def _normalize_export_func(export_func, output_format) -> Tuple[Callable, str]:
return export_func, output_format
+def local_run( # pylint: disable=too-many-arguments,too-many-locals
+ mod: "Module",
+ device_type: str,
+ args: List[Union["np.ndarray", "NDArray", int, float]],
+ evaluator_config: Optional["EvaluatorConfig"] = None,
+ export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar",
+ output_format: Optional[str] = None,
+):
+ """Run a TVM module on a local device.
+
+ Parameters
+ ----------
+ mod : Module
+ The TVM module to run.
+ device_type : str
+ The device type to run the module on.
+ args : List[Union[np.ndarray, NDArray, int, float]]
+ The arguments to be fed to the module.
+ evaluator_config : Optional[EvaluatorConfig]
+ The evaluator configuration to use.
+ export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
+ The function to export the module to a file.
+ If callable, it must be a function that takes two arguments: the module to export and the
+ path to export to.
+ If "tar", the module will be exported to a tar file.
+ If "ndk", the module will be exported to a shared library.
+ output_format : Optional[str]
+ The format of the exported module.
+ If not specified, it will be inferred from the `export_func` argument.
+
+ Returns
+ -------
+ args : List[Union[np.ndarray, NDArray, int, float]]
+ The results of running the module.
+ """
+ import os.path as osp
+ import tempfile
+
+ from tvm.meta_schedule.runner import EvaluatorConfig
+ from tvm.runtime import device, load_module
+
+ evaluator_config = EvaluatorConfig._normalized(evaluator_config)
+ export_func, output_format = _normalize_export_func(export_func, output_format)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
+ export_func(mod, artifact_path)
+ device: Device = device(device_type, 0)
+
+ try:
+ args = _args_to_device(args, device)
+ remote_mod = load_module(artifact_path)
+ profile_result = remote_mod.time_evaluator(
+ func_name=remote_mod.entry_name,
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )(*args)
+ print(profile_result)
+ remote_mod(*args)
+ args = _args_to_numpy(args)
+ finally:
+ pass
+
+ return args
+
+
def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
mod: "Module",
device_type: str,
@@ -103,7 +172,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
If not specified, the default RPC configuration will be used, which reads the following
environment variables:
- TVM_TRACKER_HOST
- - TVM_TRACKER_PORmod
+ - TVM_TRACKER_PORT
- TVM_TRACKER_KEY
export_func : Union[Callable[Module, str], Literal["tar", "ndk"]]
The function to export the module to a file.
@@ -134,12 +203,12 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
_, remote_path = osp.split(artifact_path)
session = rpc_config.connect_server()
- device: Device = session.device(dev_type=device_type, dev_id=0)
+ device: Device = session.device(device_type, 0)
export_func(mod, artifact_path)
try:
session.upload(artifact_path, remote_path)
- args = _args_to_remote(args, device)
+ args = _args_to_device(args, device)
remote_mod = session.load_module(remote_path)
profile_result = remote_mod.time_evaluator(
func_name=remote_mod.entry_name,
@@ -153,7 +222,7 @@ def rpc_run( # pylint: disable=too-many-arguments,too-many-locals
)(*args)
print(profile_result)
remote_mod(*args)
- args = _args_to_local(args)
+ args = _args_to_numpy(args)
finally:
session.remove(remote_path)
session.remove(remote_path + "." + output_format)