zxybazh opened a new pull request, #15322:
URL: https://github.com/apache/tvm/pull/15322
This PR introduces benchmarking tools for dynamic-shape PrimFuncs and Relax functions.
It provides the following functionality:
1. Automatically benchmark PrimFunc performance from a user-specified sampling
function for the dynamic shapes plus input information. E.g., `n=random.randint(50, 100)`
benchmarks the PrimFunc with its dynamic dimension `n` sampled between 50 and 100
(inclusive).
2. Extract self-contained PrimFunc benchmarking files from a Relax module.
E.g., it can produce a separate Python file for each function in a Relax module
and automatically extract the dynamic-shape input information from the bindings.
3. Benchmark at the Relax function level, using the same sampled dynamic-shape
value throughout the Relax function. E.g., the same value of `n` is used
consistently for every PrimFunc call in a Relax function, so we can identify the
performance bottleneck once the values of all dynamic dimensions such as `n` or
`m` are pinned down (yes, there can be multiple dynamic dimensions).
4. Automatically generate valid inputs even when an input is not a tensor with a
dynamic shape but the value of the dynamic dimension itself, i.e. `n`. For now
this is specific to `rotatry_embedding`.
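Items 1 and 3 come down to two steps. First, draw one value per dynamic dimension. Second, substitute that same value into every symbolic shape that mentions it. The following is a minimal plain-Python sketch of that idea. It assumes shapes are written as tuples mixing ints and dimension-name strings, as in the example usage below. `sample_dym_vars` and `concretize` are hypothetical helpers for illustration, not the API added by this PR.

```python
import random
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]  # a static extent or the name of a dynamic dimension


def sample_dym_vars(names: List[str], low: int, high: int,
                    rng: random.Random) -> Dict[str, int]:
    """Hypothetical helper: draw one value per dynamic dimension in [low, high]."""
    return {name: rng.randint(low, high) for name in names}


def concretize(shape: Tuple[Dim, ...], sample: Dict[str, int]) -> Tuple[int, ...]:
    """Hypothetical helper: replace each dynamic dimension name with its value."""
    return tuple(sample[d] if isinstance(d, str) else d for d in shape)


rng = random.Random(0)  # seeded so runs are reproducible
sample = sample_dym_vars(["n", "m"], 50, 100, rng)

# One sample is shared by every PrimFunc call in a trial, so all kernels
# see a consistent value of `n` and `m` (the idea behind item 3).
matmul_shape = concretize((1, "n", 4096), sample)
attn_shape = concretize((1, 32, "n", "m"), sample)
assert matmul_shape[1] == attn_shape[2] == sample["n"]
```

Sharing one sample dict across calls, rather than sampling per call, is what lets a per-function timing breakdown be attributed to a single fixed setting of all dynamic dimensions.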
Co-authored by @cyx-6
Thanks to @junrushao for the advice.
Example Usage
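```python
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.local_rpc import LocalRPC

# `benchmark_prim_func` is the benchmarking entry point added by this PR and
# `cuda_workload` is the PrimFunc under test; both are defined elsewhere.


def benchmark_prim_func_full_rpc():
    # Start a local RPC tracker and server for the measurement session.
    with LocalRPC() as rpc:
        rpc_config = ms.runner.RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        benchmark_prim_func(
            cuda_workload,
            # (shape, dtype) per argument; "m" marks the dynamic dimension.
            args=[
                ((1, "m", 4096), "float32"),
                ((4096, 4096), "float32"),
                ((1, "m", 4096), "float32"),
            ],
            dym_var_dict={"m": "int32"},  # dtype of each dynamic dimension
            target="nvidia/geforce-rtx-3070",
            dev=tvm.cuda(),
            rpc_config=rpc_config,
            evaluator_config=ms.runner.EvaluatorConfig(
                number=10,
                repeat=10,
                min_repeat_ms=0,
                enable_cpu_cache_flush=False,
            ),
        )
```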
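Each entry in `args` pairs a shape with a dtype. The shape may contain dynamic dimension names such as `"m"`. The following sketch shows how such a spec could be turned into concrete random NumPy inputs for one sampled value. `make_inputs` is a hypothetical helper for illustration, not the function this PR adds. It assumes floating-point dtypes, because it fills arrays from a uniform [0, 1) distribution.

```python
from typing import Dict, List, Tuple, Union

import numpy as np

ArgSpec = Tuple[Tuple[Union[int, str], ...], str]  # (shape, dtype)


def make_inputs(args: List[ArgSpec], sample: Dict[str, int],
                seed: int = 0) -> List[np.ndarray]:
    """Hypothetical helper: one random float array per argument spec."""
    rng = np.random.default_rng(seed)
    arrays = []
    for shape, dtype in args:
        # Substitute the sampled value for every named dynamic dimension.
        concrete = tuple(sample[d] if isinstance(d, str) else d for d in shape)
        arrays.append(rng.random(concrete).astype(dtype))
    return arrays


args = [
    ((1, "m", 4096), "float32"),
    ((4096, 4096), "float32"),
    ((1, "m", 4096), "float32"),
]
inputs = make_inputs(args, {"m": 56})
print([a.shape for a in inputs])  # [(1, 56, 4096), (4096, 4096), (1, 56, 4096)]
```

Seeding the generator keeps the inputs identical across repeated trials with the same sample, so timing differences come from the kernel rather than the data.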
Example results for the tested PrimFunc:
```
  InputInfo   Time(us)    Std(us)  Weight  WxTime(ms)
0   m = 126  752.48000  19.417429       1    0.752480
1    m = 56  430.68955   0.244274       1    0.430690
2    m = 13  340.11350   0.241286       1    0.340114
3    m = 89  692.58875   0.343988       1    0.692589
4    m = 98  819.43990   0.316655       1    0.819440
```