This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9710d81650 [Testing] Utility method to run TVM on remote device
(#15179)
9710d81650 is described below
commit 9710d816501fb518d82319c7337458095457c096
Author: Junru Shao <[email protected]>
AuthorDate: Thu Jun 29 16:39:35 2023 -0700
[Testing] Utility method to run TVM on remote device (#15179)
This PR introduces `tvm.testing.rpc_run`, a utility method that allows a
`runtime.Module` to run on a remote device via TVM RPC.
Example:
```python
import numpy as np
import tvm
from tvm.script import tir as T
from tvm.testing import rpc_run  # was "tvm.testnig", which raises ImportError


@T.prim_func
def cuda_kernel(
    A: T.Buffer((128,), "float32"),
    B: T.Buffer((128,), "float32"),
):
    # 4 blocks x 32 threads cover the 128 elements; each thread handles one.
    for bx in T.thread_binding(4, thread="blockIdx.x"):
        for tx in T.thread_binding(32, thread="threadIdx.x"):
            x = bx * 32 + tx
            B[x] = A[x] + 1.0


def main():
    """Build `cuda_kernel`, run it remotely through `rpc_run`, and check the output."""
    np_a = np.random.randn(128).astype("float32")
    np_b = np_a + 1.0
    rt_mod = tvm.build(cuda_kernel, target="nvidia/geforce-rtx-3090-ti")
    tvm_a, tvm_b = rpc_run(
        rt_mod,
        "cuda",
        [np_a, np_b],
    )
    assert np.allclose(tvm_b, np_b)


if __name__ == "__main__":
    main()
```
Result:
```
Execution time summary:
mean (ms) median (ms) max (ms) min (ms) std (ms)
0.0023 0.0023 0.0023 0.0023 0.0000
```
---
python/tvm/testing/__init__.py | 37 +++++++---
python/tvm/testing/rpc_run.py | 162 +++++++++++++++++++++++++++++++++++++++++
2 files changed, 188 insertions(+), 11 deletions(-)
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index d84846725e..08c0926277 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -18,15 +18,30 @@
# pylint: disable=redefined-builtin, wildcard-import
"""Utility Python functions for TVM testing"""
+from . import auto_scheduler, autotvm
+from ._ffi_api import (
+ ErrorTest,
+ FrontendTestModule,
+ device_test,
+ echo,
+ identity_cpp,
+ nop,
+ object_use_count,
+ run_check_signal,
+ test_check_eq_callback,
+ test_raise_error_callback,
+ test_wrap_callback,
+)
+from .popen_pool import (
+ after_initializer,
+ call_cpp_ffi,
+ call_cpp_py_ffi,
+ call_py_ffi,
+ fast_summation,
+ initializer,
+ register_ffi,
+ slow_summation,
+ timeout_job,
+)
+from .rpc_run import rpc_run
from .utils import *
-
-from ._ffi_api import nop, echo, device_test, run_check_signal,
object_use_count
-from ._ffi_api import test_wrap_callback, test_raise_error_callback,
test_check_eq_callback
-from ._ffi_api import ErrorTest, FrontendTestModule, identity_cpp
-
-from .popen_pool import initializer, after_initializer, register_ffi,
call_cpp_ffi
-from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation,
slow_summation
-from .popen_pool import timeout_job
-
-from . import auto_scheduler
-from . import autotvm
diff --git a/python/tvm/testing/rpc_run.py b/python/tvm/testing/rpc_run.py
new file mode 100644
index 0000000000..08c00ca4d1
--- /dev/null
+++ b/python/tvm/testing/rpc_run.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, missing-function-docstring
+"""A utility method to run a TVM module on a remote device."""
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
+
+from typing_extensions import Literal
+
+if TYPE_CHECKING:
+ import numpy as np
+
+ from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
+ from tvm.runtime import Device, Module, NDArray
+
+# pylint: disable=import-outside-toplevel,protected-access
+
+
def _args_to_remote(args, device):
    """Stage host-side call arguments on ``device``.

    ``numpy.ndarray`` and TVM ``NDArray`` values are copied into newly
    allocated ``NDArray``s of the same shape and dtype on ``device``.
    Python ``int``/``float`` scalars are forwarded unchanged. Any other
    argument type raises ``ValueError``.

    Returns a new list of staged arguments, in the original order.
    """
    import numpy as np

    from tvm.runtime.ndarray import NDArray, empty

    staged_args = []
    for value in args:
        if isinstance(value, (np.ndarray, NDArray)):
            # Allocate a matching buffer on the target device, then fill it.
            buffer = empty(value.shape, dtype=value.dtype, device=device)
            staged_args.append(buffer.copyfrom(value))
            continue
        if not isinstance(value, (int, float)):
            raise ValueError(f"Unsupported input type: {type(value)}")
        staged_args.append(value)
    return staged_args
+
+
def _args_to_local(args):
    """Bring call arguments back to the host.

    Every TVM ``NDArray`` is converted to a numpy array via ``.numpy()``.
    All other values (e.g. scalars) are returned as-is, in the same order.
    """
    from tvm.runtime.ndarray import NDArray

    return [value.numpy() if isinstance(value, NDArray) else value for value in args]
+
+
def _normalize_export_func(export_func, output_format) -> Tuple[Callable, str]:
    """Resolve ``export_func`` into a concrete exporter plus the artifact extension.

    - ``"tar"`` maps to ``mod.export_library(path, tar.tar)`` and forces the
      extension to ``"tar"``.
    - ``"ndk"`` maps to ``mod.export_library(path, ndk.create_shared)`` and
      forces the extension to ``"so"``.
    - In both of those cases a caller-supplied ``output_format`` is ignored.
    - A callable is returned unchanged and requires an explicit ``output_format``.

    Raises
    ------
    ValueError
        If ``export_func`` is neither a known alias nor callable, or if it is
        callable and ``output_format`` is None.
    """
    from tvm.contrib import ndk, tar

    def _exporter_using(fcompile):
        def _export(mod, path):
            return mod.export_library(path, fcompile)

        return _export

    # (alias, compile function handed to export_library, forced file extension)
    builtin_exporters = (
        ("tar", tar.tar, "tar"),
        ("ndk", ndk.create_shared, "so"),
    )
    for alias, fcompile, extension in builtin_exporters:
        if export_func == alias:
            return _exporter_using(fcompile), extension

    if not callable(export_func):
        raise ValueError(f"Unsupported export_func: {export_func}")
    if output_format is None:
        raise ValueError("output_format must be specified if `export_func` is callable")
    return export_func, output_format
+
+
def rpc_run(  # pylint: disable=too-many-arguments,too-many-locals
    mod: "Module",
    device_type: str,
    args: List[Union["np.ndarray", "NDArray", int, float]],
    evaluator_config: Optional["EvaluatorConfig"] = None,
    rpc_config: Optional["RPCConfig"] = None,
    export_func: Union[Callable[["Module", str], None], Literal["tar", "ndk"]] = "tar",
    output_format: Optional[str] = None,
):
    """Run a TVM module on a remote device.

    Steps:

    1. The module is exported to a temporary local file and uploaded over RPC.
    2. It is loaded on the remote side and its entry function is timed with
       ``time_evaluator``; the timing summary is printed.
    3. The entry function is invoked once more.
    4. The arguments, which may have been mutated by those runs, are copied
       back to the host and returned.

    Parameters
    ----------
    mod : Module
        The TVM module to run.
    device_type : str
        The device type to run the module on (device id 0 is used).
    args : List[Union[np.ndarray, NDArray, int, float]]
        The arguments to be fed to the module. Arrays are copied to the remote
        device; int/float scalars are passed through.
    evaluator_config : Optional[EvaluatorConfig]
        The evaluator configuration to use (number, repeat, min_repeat_ms,
        CPU cache flushing). Normalized via ``EvaluatorConfig._normalized``.
    rpc_config : Optional[RPCConfig]
        The RPC configuration to connect to the remote device.
        If not specified, the default RPC configuration will be used, which
        reads the following environment variables:

        - TVM_TRACKER_HOST
        - TVM_TRACKER_PORT
        - TVM_TRACKER_KEY
    export_func : Union[Callable[[Module, str], None], Literal["tar", "ndk"]]
        The function to export the module to a file.

        - If callable, it must take two arguments: the module to export and
          the path to export to.
        - If "tar", the module is exported to a tar file.
        - If "ndk", the module is exported to a shared library.
    output_format : Optional[str]
        The file extension of the exported module. It is required when
        ``export_func`` is callable. For "tar" and "ndk" it is determined by
        ``export_func`` ("tar" / "so") and any value passed here is ignored.

    Returns
    -------
    args : List[Union[np.ndarray, int, float]]
        The arguments after running the module. ``NDArray`` entries are
        converted to numpy arrays; scalars are returned unchanged.

    Raises
    ------
    ValueError
        If ``export_func`` is unsupported, if ``output_format`` is missing for
        a callable ``export_func``, or if an argument has an unsupported type.
    """

    import os.path as osp
    import tempfile

    from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig

    evaluator_config = EvaluatorConfig._normalized(evaluator_config)
    rpc_config = RPCConfig._normalized(rpc_config)
    export_func, output_format = _normalize_export_func(export_func, output_format)

    with tempfile.TemporaryDirectory() as tmp_dir:
        artifact_path = osp.join(tmp_dir, "tvm_tmp_mod." + output_format)
        # The uploaded copy uses the bare file name on the remote side.
        remote_path = osp.basename(artifact_path)
        session = rpc_config.connect_server()
        device: Device = session.device(dev_type=device_type, dev_id=0)

        export_func(mod, artifact_path)
        try:
            session.upload(artifact_path, remote_path)
            args = _args_to_remote(args, device)
            remote_mod = session.load_module(remote_path)
            profile_result = remote_mod.time_evaluator(
                func_name=remote_mod.entry_name,
                dev=device,
                number=evaluator_config.number,
                repeat=evaluator_config.repeat,
                min_repeat_ms=evaluator_config.min_repeat_ms,
                f_preproc="cache_flush_cpu_non_first_arg"
                if evaluator_config.enable_cpu_cache_flush
                else "",
            )(*args)
            print(profile_result)
            # One more plain invocation after timing; the results are
            # downloaded from these (possibly mutated) arguments.
            remote_mod(*args)
            args = _args_to_local(args)
        finally:
            session.remove(remote_path)
            # Loading a ".tar" artifact on the server links it into
            # "<remote_path>.so" (tvm.runtime.load_module). That file must be
            # removed as well, mirroring meta_schedule's default RPC cleanup.
            # The previous suffix "." + output_format (e.g. ".tar.tar") named
            # a file that is never created, so the .so was left behind.
            session.remove(remote_path + ".so")
            session.remove("")

    return args