zxybazh commented on code in PR #15466:
URL: https://github.com/apache/tvm/pull/15466#discussion_r1295213567


##########
python/tvm/dlight/benchmark/bench.py:
##########
@@ -199,28 +212,41 @@ def benchmark_prim_func(
     rpc_config : Optional["RPCConfig"]
         The RPC configuration to connect to the remote device.
         If none, will use local mode.
-    sort_by : Optional[str]
-        Sort results by this key, if None, no sorting.
-    desc : Optional[bool]
-        Whether to sort results in descending order.
+    display_config : Optional[DisplayConfig]
+        The display configuration to use.
+        If none, will use default display configuration.
     """
+    if display_config is None:
+        display_config = DisplayConfig()
     results = []
-    if dym_var_dict is None or args is None:
-        args, dym_var_dict = extract_func_info_from_prim_func(mod_or_func)
-    for _ in range(sample_number):
-        dym_var_sample = dym_var_sample_func(dym_var_dict)
-        _, median, std = benchmark(
+    if dym_vars is None or args is None:
+        if isinstance(mod_or_func, tvm.tir.PrimFunc):
+            args, dym_vars = extract_func_info_from_prim_func(mod_or_func)
+        else:
+            gvs = mod_or_func.get_global_vars()
+            assert len(gvs) == 1, "Only support one PrimFunc in IRModule"
+            args, dym_vars = 
extract_func_info_from_prim_func(mod_or_func[gvs[0]])
+    if len(dym_vars) == 0:  # static shape
+        sample_num = 1
+        dym_var_sample_func = lambda dym_vars, sample_idx, sample_num: {}
+    for sample_idx in range(sample_num):
+        dym_var_sample = dym_var_sample_func(dym_vars, sample_idx, sample_num)
+        inputs_infos, median, std = benchmark(
             mod_or_func,
             args=args,
             dym_var_sample=dym_var_sample,
             target=target,
             evaluator_config=evaluator_config,
             rpc_config=rpc_config,
         )
+        _, total_input_bytes = inputs_infos
         row = {
-            "InputInfo": ", ".join([f"{k} = {v}" for k, v in 
dym_var_sample.items()]),
+            "InputInfo": ", ".join([f"{k} = {v}" for k, v in 
dym_var_sample.items()])
+            if len(dym_vars) > 0
+            else "static",
             "Time(us)": median * 1e6,
             "Std(us)": std * 1e6,
+            "Memory(GB/s)": total_input_bytes / median / 1024**3,

Review Comment:
   Thanks for the tip! I think it makes sense to have a maximum bandwidth if 
the workload is memory-bound. On the other hand, this is a rough calculation of 
throughput, so it may not reflect the computational FLOPS; we can show the peak 
FLOPS of the hardware for comparison.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to