sunggg commented on code in PR #15466:
URL: https://github.com/apache/tvm/pull/15466#discussion_r1294057548


##########
python/tvm/dlight/benchmark/bench.py:
##########
@@ -199,28 +212,41 @@ def benchmark_prim_func(
     rpc_config : Optional["RPCConfig"]
         The RPC configuration to connect to the remote device.
         If none, will use local mode.
-    sort_by : Optional[str]
-        Sort results by this key, if None, no sorting.
-    desc : Optional[bool]
-        Whether to sort results in descending order.
+    display_config : Optional[DisplayConfig]
+        The display configuration to use.
+        If none, will use default display configuration.
     """
+    if display_config is None:
+        display_config = DisplayConfig()
     results = []
-    if dym_var_dict is None or args is None:
-        args, dym_var_dict = extract_func_info_from_prim_func(mod_or_func)
-    for _ in range(sample_number):
-        dym_var_sample = dym_var_sample_func(dym_var_dict)
-        _, median, std = benchmark(
+    if dym_vars is None or args is None:
+        if isinstance(mod_or_func, tvm.tir.PrimFunc):
+            args, dym_vars = extract_func_info_from_prim_func(mod_or_func)
+        else:
+            gvs = mod_or_func.get_global_vars()
+            assert len(gvs) == 1, "Only support one PrimFunc in IRModule"
+            args, dym_vars = extract_func_info_from_prim_func(mod_or_func[gvs[0]])
+    if len(dym_vars) == 0:  # static shape
+        sample_num = 1
+        dym_var_sample_func = lambda dym_vars, sample_idx, sample_num: {}
+    for sample_idx in range(sample_num):
+        dym_var_sample = dym_var_sample_func(dym_vars, sample_idx, sample_num)
+        inputs_infos, median, std = benchmark(
             mod_or_func,
             args=args,
             dym_var_sample=dym_var_sample,
             target=target,
             evaluator_config=evaluator_config,
             rpc_config=rpc_config,
         )
+        _, total_input_bytes = inputs_infos
         row = {
-            "InputInfo": ", ".join([f"{k} = {v}" for k, v in dym_var_sample.items()]),
+            "InputInfo": ", ".join([f"{k} = {v}" for k, v in dym_var_sample.items()])
+            if len(dym_vars) > 0
+            else "static",
             "Time(us)": median * 1e6,
             "Std(us)": std * 1e6,
+            "Memory(GB/s)": total_input_bytes / median / 1024**3,

Review Comment:
   One potential follow-up is to also report the hardware's peak memory bandwidth. That would make it easier to see how far the measured number is from the theoretical upper bound.
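
   A minimal sketch of what that comparison could look like, assuming the peak bandwidth is supplied by the caller (for example from a vendor datasheet). The helper names, the `peak_gbps` parameter, and the sample numbers are hypothetical and not part of this PR; the first function only mirrors the formula on the line above. Note that dividing by `1024**3` technically yields GiB/s, so the peak value should use the same unit.

```python
def achieved_bandwidth_gbps(total_input_bytes: int, median_seconds: float) -> float:
    """Same formula as the "Memory(GB/s)" column in the diff above."""
    return total_input_bytes / median_seconds / 1024**3


def bandwidth_utilization(achieved_gbps: float, peak_gbps: float) -> float:
    """Fraction of the caller-supplied peak bandwidth that was reached."""
    if peak_gbps <= 0:
        raise ValueError("peak_gbps must be positive")
    return achieved_gbps / peak_gbps


# Hypothetical numbers, for illustration only: 2 GiB of inputs read in 10 ms
# on a device whose peak bandwidth is assumed to be 400 GiB/s.
achieved = achieved_bandwidth_gbps(total_input_bytes=2 * 1024**3, median_seconds=0.01)
utilization = bandwidth_utilization(achieved, peak_gbps=400.0)
print(f"{achieved:.1f} GB/s, {utilization:.0%} of peak")
```

   Keeping the peak value as an explicit input avoids guessing it per target; whether it should instead be looked up from the target description is left open here.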



##########
python/tvm/dlight/benchmark/bench.py:
##########
@@ -75,8 +79,10 @@ def benchmark(
 
     Returns
     -------
-    input_infos : List[Tuple[Tuple[int, ...], str]]
-        The input tensor information, including shape and dtype.
+    input_infos : Tuple[List[Tuple[Tuple[int, ...], str]], int]
+        The input tensor information, including shape and dtype, and the total input bytes.
+        E.g., [((1, 32, 1, 64), "float16"), ((1, 32, 64, 128), "float16")], 528384

Review Comment:
   nit: write the example as `([((1, 32, 1, 64), "float16"), ((1, 32, 64, 128), "float16")], 528384)` so that it is a tuple, matching the documented return type.
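
   For reference, the `528384` in that example is consistent with summing element count times dtype size over the listed tensors. The sketch below checks this; `DTYPE_BYTES` and `total_input_bytes` are hypothetical names used only for illustration, not the helpers in `bench.py`, and the dtype table covers just a few common types.

```python
import math

# Bytes per element for a few common dtypes (illustrative, not exhaustive).
DTYPE_BYTES = {"float16": 2, "float32": 4, "float64": 8, "int8": 1, "int32": 4, "int64": 8}


def total_input_bytes(tensor_infos):
    """Sum of (number of elements * bytes per element) over all tensors."""
    return sum(math.prod(shape) * DTYPE_BYTES[dtype] for shape, dtype in tensor_infos)


tensor_infos = [((1, 32, 1, 64), "float16"), ((1, 32, 64, 128), "float16")]

# Shape of the documented return value: (list of (shape, dtype), total bytes).
input_infos = (tensor_infos, total_input_bytes(tensor_infos))
print(input_infos[1])  # 528384 = 32*64*2 + 32*64*128*2
```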



