This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 270c793574 Extend gpu memory bandwidth test to work through RPC (#16608)
270c793574 is described below

commit 270c7935742cef2f1a45463de181ef37c9a69ab5
Author: Egor Churaev <[email protected]>
AuthorDate: Tue Feb 20 04:55:25 2024 +0300

    Extend gpu memory bandwidth test to work through RPC (#16608)
    
    This change makes it possible to measure GPU memory bandwidth on an
    Android device through RPC.
---
 python/tvm/exec/gpu_memory_bandwidth.py | 98 ++++++++++++++++++++++++++++-----
 1 file changed, 83 insertions(+), 15 deletions(-)

diff --git a/python/tvm/exec/gpu_memory_bandwidth.py b/python/tvm/exec/gpu_memory_bandwidth.py
index a5f2021f73..4fdc6aef2f 100644
--- a/python/tvm/exec/gpu_memory_bandwidth.py
+++ b/python/tvm/exec/gpu_memory_bandwidth.py
@@ -22,8 +22,8 @@ import numpy as np
 
 import tvm
 from tvm import te, tir
-from tvm.meta_schedule.runner import EvaluatorConfig
-from tvm.testing import local_run
+from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
+from tvm.testing import local_run, rpc_run
 
 
 def _parse_args() -> argparse.Namespace:
@@ -32,12 +32,23 @@ def _parse_args() -> argparse.Namespace:
 
     parser = argparse.ArgumentParser(
         prog="GPU memory bandwidth testing",
-        description="""Example:
+        description="""Example for host GPU:
     python -m tvm.exec.gpu_memory_bandwidth "nvidia/geforce-rtx-3090-ti" \
         --dtype "float32"
         --bx "8,16,32,64,128,256"      \
         --tx "32,64,128,256,512,1024"  \
-        --vec "1,2,4"
+        --vec "1,2,4" \
+
+    Example for Android GPU: \
+    python -m tvm.exec.gpu_memory_bandwidth "opencl" --target_host "llvm 
-mtriple=arm64-linux-android" \
+        --rpc_host "127.0.0.1" \
+        --rpc_port 9190 \
+        --rpc_key "android" \
+        --export_func "ndk" \
+        --dtype "float32" \
+        --bx "8,16,32,64,128,256"      \
+        --tx "32,64,128,256,512,1024"  \
+        --vec "1,2,4" \
 """,
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
@@ -46,6 +57,12 @@ def _parse_args() -> argparse.Namespace:
         type=str,
         help="The target to be benchmarked",
     )
+    parser.add_argument(
+        "--target_host",
+        type=str,
+        default=None,
+        help="The target host for build",
+    )
     parser.add_argument(
         "--xo",
         type=int,
@@ -88,6 +105,31 @@ def _parse_args() -> argparse.Namespace:
         default=[1, 2, 4],
         help="Vector length to be used in vectorized load",
     )
+    parser.add_argument(
+        "--rpc_host",
+        type=str,
+        default=None,
+        help="The address of RPC host (default: None, that means that RPC is 
not used)",
+    )
+    parser.add_argument(
+        "--rpc_port",
+        type=int,
+        default=None,
+        help="The port of RPC connection (default: None, that means that RPC 
is not used)",
+    )
+    parser.add_argument(
+        "--rpc_key",
+        type=str,
+        default=None,
+        help="The device key in RPC tracker (default: None, that means that 
RPC is not used)",
+    )
+    parser.add_argument(
+        "--export_func",
+        type=str,
+        default="tar",
+        help="Export function, actual only for RPC",
+        choices=["tar", "ndk"],
+    )
     return parser.parse_args()
 
 
@@ -136,7 +178,18 @@ def main():  # pylint: disable=too-many-locals
     args = _parse_args()
     # pylint: disable=invalid-name
     target = tvm.target.Target(args.target)
+    if args.target_host is not None:
+        target = tvm.target.Target(args.target, host=args.target_host)
     dtype = args.dtype
+    rpcConfig = None
+    if args.rpc_host is not None and args.rpc_port is not None and 
args.rpc_key is not None:
+        rpcConfig = RPCConfig(
+            tracker_host=args.rpc_host,
+            tracker_port=args.rpc_port,
+            tracker_key=args.rpc_key,
+            session_priority=1,
+            session_timeout_sec=600,
+        )
 
     a = np.random.uniform(-1, 1, (args.xo, args.k, args.xi)).astype(dtype)
     b = np.zeros((args.xo, args.xi), dtype=dtype)
@@ -165,17 +218,32 @@ def main():  # pylint: disable=too-many-locals
         sch = tir.Schedule(func)
         _schedule(sch, len_bx, len_tx, len_vec)
 
-        _, profile_result = local_run(
-            tvm.build(sch.mod, target=target),
-            target.kind.name,
-            [a, b],
-            evaluator_config=EvaluatorConfig(
-                number=10,
-                repeat=1,
-                min_repeat_ms=100,
-                enable_cpu_cache_flush=False,
-            ),
-        )
+        if rpcConfig is None:
+            _, profile_result = local_run(
+                tvm.build(sch.mod, target=target),
+                target.kind.name,
+                [a, b],
+                evaluator_config=EvaluatorConfig(
+                    number=10,
+                    repeat=1,
+                    min_repeat_ms=100,
+                    enable_cpu_cache_flush=False,
+                ),
+            )
+        else:
+            _, profile_result = rpc_run(
+                tvm.build(sch.mod, target=target),
+                target.kind.name,
+                [a, b],
+                evaluator_config=EvaluatorConfig(
+                    number=10,
+                    repeat=1,
+                    min_repeat_ms=100,
+                    enable_cpu_cache_flush=False,
+                ),
+                rpc_config=rpcConfig,
+                export_func=args.export_func,
+            )
         bandwidth = num_bytes / profile_result.mean / (1024**3)
         bx = len_bx * args.xi // (len_tx * len_vec)  # pylint: 
disable=invalid-name
         mbs = num_bytes / 1024 / 1024

Reply via email to