This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 270c793574 Extend gpu memory bandwidth test to work through RPC (#16608)
270c793574 is described below
commit 270c7935742cef2f1a45463de181ef37c9a69ab5
Author: Egor Churaev <[email protected]>
AuthorDate: Tue Feb 20 04:55:25 2024 +0300
Extend gpu memory bandwidth test to work through RPC (#16608)
This change makes it possible to measure GPU memory bandwidth through RPC
on an Android device.
---
python/tvm/exec/gpu_memory_bandwidth.py | 98 ++++++++++++++++++++++++++++-----
1 file changed, 83 insertions(+), 15 deletions(-)
diff --git a/python/tvm/exec/gpu_memory_bandwidth.py b/python/tvm/exec/gpu_memory_bandwidth.py
index a5f2021f73..4fdc6aef2f 100644
--- a/python/tvm/exec/gpu_memory_bandwidth.py
+++ b/python/tvm/exec/gpu_memory_bandwidth.py
@@ -22,8 +22,8 @@ import numpy as np
import tvm
from tvm import te, tir
-from tvm.meta_schedule.runner import EvaluatorConfig
-from tvm.testing import local_run
+from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
+from tvm.testing import local_run, rpc_run
def _parse_args() -> argparse.Namespace:
@@ -32,12 +32,23 @@ def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="GPU memory bandwidth testing",
- description="""Example:
+ description="""Example for host GPU:
python -m tvm.exec.gpu_memory_bandwidth "nvidia/geforce-rtx-3090-ti" \
--dtype "float32"
--bx "8,16,32,64,128,256" \
--tx "32,64,128,256,512,1024" \
- --vec "1,2,4"
+ --vec "1,2,4" \
+
+ Example for Android GPU: \
+ python -m tvm.exec.gpu_memory_bandwidth "opencl" --target_host "llvm -mtriple=arm64-linux-android" \
+ --rpc_host "127.0.0.1" \
+ --rpc_port 9190 \
+ --rpc_key "android" \
+ --export_func "ndk" \
+ --dtype "float32" \
+ --bx "8,16,32,64,128,256" \
+ --tx "32,64,128,256,512,1024" \
+ --vec "1,2,4" \
""",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
@@ -46,6 +57,12 @@ def _parse_args() -> argparse.Namespace:
type=str,
help="The target to be benchmarked",
)
+ parser.add_argument(
+ "--target_host",
+ type=str,
+ default=None,
+ help="The target host for build",
+ )
parser.add_argument(
"--xo",
type=int,
@@ -88,6 +105,31 @@ def _parse_args() -> argparse.Namespace:
default=[1, 2, 4],
help="Vector length to be used in vectorized load",
)
+ parser.add_argument(
+ "--rpc_host",
+ type=str,
+ default=None,
+ help="The address of RPC host (default: None, that means that RPC is not used)",
+ )
+ parser.add_argument(
+ "--rpc_port",
+ type=int,
+ default=None,
+ help="The port of RPC connection (default: None, that means that RPC is not used)",
+ )
+ parser.add_argument(
+ "--rpc_key",
+ type=str,
+ default=None,
+ help="The device key in RPC tracker (default: None, that means that RPC is not used)",
+ )
+ parser.add_argument(
+ "--export_func",
+ type=str,
+ default="tar",
+ help="Export function, actual only for RPC",
+ choices=["tar", "ndk"],
+ )
return parser.parse_args()
@@ -136,7 +178,18 @@ def main(): # pylint: disable=too-many-locals
args = _parse_args()
# pylint: disable=invalid-name
target = tvm.target.Target(args.target)
+ if args.target_host is not None:
+ target = tvm.target.Target(args.target, host=args.target_host)
dtype = args.dtype
+ rpcConfig = None
+ if args.rpc_host is not None and args.rpc_port is not None and args.rpc_key is not None:
+ rpcConfig = RPCConfig(
+ tracker_host=args.rpc_host,
+ tracker_port=args.rpc_port,
+ tracker_key=args.rpc_key,
+ session_priority=1,
+ session_timeout_sec=600,
+ )
a = np.random.uniform(-1, 1, (args.xo, args.k, args.xi)).astype(dtype)
b = np.zeros((args.xo, args.xi), dtype=dtype)
@@ -165,17 +218,32 @@ def main(): # pylint: disable=too-many-locals
sch = tir.Schedule(func)
_schedule(sch, len_bx, len_tx, len_vec)
- _, profile_result = local_run(
- tvm.build(sch.mod, target=target),
- target.kind.name,
- [a, b],
- evaluator_config=EvaluatorConfig(
- number=10,
- repeat=1,
- min_repeat_ms=100,
- enable_cpu_cache_flush=False,
- ),
- )
+ if rpcConfig is None:
+ _, profile_result = local_run(
+ tvm.build(sch.mod, target=target),
+ target.kind.name,
+ [a, b],
+ evaluator_config=EvaluatorConfig(
+ number=10,
+ repeat=1,
+ min_repeat_ms=100,
+ enable_cpu_cache_flush=False,
+ ),
+ )
+ else:
+ _, profile_result = rpc_run(
+ tvm.build(sch.mod, target=target),
+ target.kind.name,
+ [a, b],
+ evaluator_config=EvaluatorConfig(
+ number=10,
+ repeat=1,
+ min_repeat_ms=100,
+ enable_cpu_cache_flush=False,
+ ),
+ rpc_config=rpcConfig,
+ export_func=args.export_func,
+ )
bandwidth = num_bytes / profile_result.mean / (1024**3)
bx = len_bx * args.xi // (len_tx * len_vec) # pylint: disable=invalid-name
mbs = num_bytes / 1024 / 1024