This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0681b3959f [Dlight] Benchmarking Tools for Dynamic Shape PrimFuncs & 
Relax Function (#15322)
0681b3959f is described below

commit 0681b3959fe65abe8a73c49fa80d1c4f0e793b43
Author: Xiyou Zhou <[email protected]>
AuthorDate: Fri Jul 21 16:31:13 2023 -0700

    [Dlight] Benchmarking Tools for Dynamic Shape PrimFuncs & Relax Function 
(#15322)
    
    This PR introduces benchmarking tools for dynamic shape PrimFuncs and 
Relax. It facilitates the following functionalities:
    1. Automatically benchmark PrimFunc performance with a user-specified dynamic shape sampling function and input information. E.g., `n=random.randint(50, 100)` will benchmark the PrimFunc with dynamic dimension `n` between 50 and 100.
    2. Extract self-contained PrimFunc benchmarking files from Relax Module. 
E.g., it can produce multiple python files for each function in a Relax Module 
and automatically extract dynamic shape input information from bindings.
    3. Conduct Relax Function level benchmarking using the same dynamic shape sample value across the Relax Function. E.g., the same value `n` is used consistently in a Relax Function for every PrimFunc call, so we can figure out the performance bottleneck once the values of all dynamic dimensions like `n` or `m` are pinned down (yes, there can be multiple dynamic dimensions).
    4. It can automatically generate valid input even when the input is not a dynamically shaped tensor but the value of the dynamic dimension itself, e.g. `n`. This is currently specific to `rotary_embedding`.
    
    Example Usage
    ```python
    def benchmark_prim_func_full_rpc():
        with LocalRPC() as rpc:
            rpc_config = ms.runner.RPCConfig(
                tracker_host=rpc.tracker_host,
                tracker_port=rpc.tracker_port,
                tracker_key=rpc.tracker_key,
                session_priority=1,
                session_timeout_sec=100,
            )
            benchmark_prim_func(
                cuda_workload,
                args=[
                    ((1, "m", 4096), "float32"),
                    ((4096, 4096), "float32"),
                    ((1, "m", 4096), "float32"),
                ],
                dym_var_dict={"m": "int32"},
                target="nvidia/geforce-rtx-3070",
                dev=tvm.cuda(),
                rpc_config=rpc_config,
                evaluator_config=ms.runner.EvaluatorConfig(
                    number=10,
                    repeat=10,
                    min_repeat_ms=0,
                    enable_cpu_cache_flush=False,
                ),
            )
    ```
    
    Expected Results for the tested PrimFunc:
    ```
      InputInfo   Time(us)    Std(us)  Weight  WxTime(ms)
    0   m = 126  752.48000  19.417429       1    0.752480
    1    m = 56  430.68955   0.244274       1    0.430690
    2    m = 13  340.11350   0.241286       1    0.340114
    3    m = 89  692.58875   0.343988       1    0.692589
    4    m = 98  819.43990   0.316655       1    0.819440
    ```
    
    
    Co-authored-by: Yaxing Cai <[email protected]>
---
 python/tvm/dlight/benchmark/__init__.py        |  24 ++
 python/tvm/dlight/benchmark/bench.py           | 312 ++++++++++++++++++++++
 python/tvm/dlight/benchmark/extract.py         | 351 +++++++++++++++++++++++++
 python/tvm/dlight/benchmark/utils.py           | 172 ++++++++++++
 python/tvm/meta_schedule/testing/tune_utils.py |   5 -
 tests/python/dlight/test_benchmark.py          | 316 ++++++++++++++++++++++
 6 files changed, 1175 insertions(+), 5 deletions(-)

diff --git a/python/tvm/dlight/benchmark/__init__.py 
b/python/tvm/dlight/benchmark/__init__.py
new file mode 100644
index 0000000000..a21e2590db
--- /dev/null
+++ b/python/tvm/dlight/benchmark/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmarking dynamic shape workloads"""
+from .bench import benchmark, benchmark_prim_func, benchmark_relax_func
+from .extract import (
+    extract_prim_func,
+    extract_from_relax,
+    extract_func_info_from_prim_func,
+    extract_all_func_info_from_relax,
+)
diff --git a/python/tvm/dlight/benchmark/bench.py 
b/python/tvm/dlight/benchmark/bench.py
new file mode 100644
index 0000000000..850a1b46d6
--- /dev/null
+++ b/python/tvm/dlight/benchmark/bench.py
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Extract self-contained benchmarking scripts for dynamic shape workloads"""
+
+from typing import TYPE_CHECKING, Dict, List, Union, Callable, Tuple, Optional
+
+
+import tvm
+from tvm import relax
+from tvm.ir import IRModule
+from tvm.tir import PrimFunc
+from tvm.meta_schedule.runner import EvaluatorConfig
+from tvm.meta_schedule.testing.tune_utils import generate_input_data
+from tvm.testing import rpc_run
+
+
+from .extract import extract_all_func_info_from_relax, 
extract_func_info_from_prim_func
+from .utils import (
+    populuate_input_shape,
+    default_dym_var_sample_func,
+    get_func_name_from_gv,
+    dym_var_sample_str,
+    print_results,
+)
+
+if TYPE_CHECKING:
+    from tvm.meta_schedule.runner import RPCConfig
+
+
+def benchmark(
+    mod_or_func: Union[PrimFunc, IRModule],
+    *,
+    dym_var_sample: Dict[str, int],
+    args: Optional[List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, 
str], ...], str]]]],
+    target: Optional[Union[str, tvm.target.Target]] = None,
+    func_name: Optional[str] = None,
+    evaluator_config: Optional["EvaluatorConfig"] = None,
+    rpc_config: Optional["RPCConfig"] = None,
+) -> Tuple[List[Tuple[Tuple[int, ...], str]], float, float]:
+    """Benchmark a PrimFunc or IRModule with dynamic input shapes.
+
+    Parameters
+    ----------
+    mod_or_func : Union[PrimFunc, IRModule]
+        The PrimFunc or IRModule to be benchmarked.
+    dym_var_sample : Optional[Dict[str, int]]
+        The dynamic shape variable sample, e.g., {"n": 64, "m": 128}.
+    args : Optional[List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, 
str], ...], str]]]]
+        The input tensor information, including shape and dtype. If none, will 
use
+        the input information from the PrimFunc or IRModule.
+    target : Optional[Union[str, tvm.target.Target]]
+        The target to be benchmarked on, if none, will get the target from 
context.
+    func_name : Optional[str]
+        The name of the function to be benchmarked, will use "main" by default.
+    evaluator_config : Optional["EvaluatorConfig"]
+        The evaluator configuration to use.
+        If none, will use default evaluator configuration.
+    rpc_config : Optional["RPCConfig"]
+        The RPC configuration to connect to the remote device.
+        If none, will use local mode.
+
+    Returns
+    -------
+    input_infos : List[Tuple[Tuple[int, ...], str]]
+        The input tensor information, including shape and dtype.
+    median : float
+        The median of the benchmarking results.
+    std : float
+        The standard deviation of the benchmarking results.
+    """
+    # produce IRModule and function name
+    if isinstance(mod_or_func, PrimFunc):
+        func_name = "main" if func_name is None else func_name
+        mod = IRModule.from_expr(mod_or_func.with_attr("global_symbol", 
func_name))
+    else:
+        mod = mod_or_func
+        # assume only one global function
+        (func_name,) = mod.get_global_vars()
+    # produce input shapes
+    if args is None:
+        args, _ = extract_func_info_from_prim_func(mod[func_name])
+    # produce target & device
+    target = tvm.target.Target.current() if target is None else 
tvm.target.Target(target)
+    if target is None:
+        raise ValueError("Target is not specified")
+    if target.kind.name == "llvm":
+        dev = tvm.cpu()
+    elif target.kind.name == "cuda":
+        dev = tvm.cuda()
+    else:
+        raise ValueError(f"Unsupported device type from {target.kind.name}")
+    # populate input shapes
+    input_infos = populuate_input_shape(args, dym_var_sample)
+    # generate input tensors, including scalars
+    # scalars are appended to the end of the list due to parsing order
+    input_tensors: List[Union[tvm.nd.NDArray, int]] = []
+    scalar_input_tensors: List[int] = []
+    for input_shape, input_dtype in input_infos:
+        if input_dtype == "scalar":
+            # special case like [n], generate int value
+            assert len(input_shape) == 1
+            scalar_input_tensors.append(input_shape[0])
+        else:
+            # normal case like [1, n, 128], generate random tensor
+            input_tensors.append(
+                tvm.nd.array(generate_input_data(list(input_shape), 
input_dtype), device=dev)
+            )
+    # append scalar input tensors for rotary embedding
+    input_tensors.extend(scalar_input_tensors)
+    # build locally
+    rt_mod = tvm.build(mod, target=target)
+    # set up evaluator config
+    evaluator_config = EvaluatorConfig._normalized(  # pylint: 
disable=protected-access
+        evaluator_config
+    )
+    # run benchmark
+    if rpc_config is None:
+        profile_result = rt_mod.time_evaluator(
+            func_name,
+            dev=dev,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )(*input_tensors)
+    else:
+        _, profile_result = rpc_run(
+            rt_mod,
+            device_type=dev.MASK2STR[dev.device_type],
+            args=[w.numpy() if isinstance(w, tvm.nd.NDArray) else w for w in 
input_tensors],
+            rpc_config=rpc_config,
+            evaluator_config=evaluator_config,
+        )
+    # return input infos, median, std
+    return input_infos, profile_result.median, profile_result.std
+
+
+def benchmark_prim_func(
+    mod_or_func: Union[PrimFunc, IRModule],
+    *,
+    dym_var_sample_func: Callable[[Dict[str, str]], Dict[str, int]] = 
default_dym_var_sample_func,
+    args: Optional[
+        List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, str], ...], 
str]]]
+    ] = None,
+    dym_var_dict: Optional[Dict[str, str]] = None,
+    sample_number: int = 5,
+    target: Optional[Union[str, tvm.target.Target]] = None,
+    weight: Optional[int] = 1,
+    relax_func_name: Optional[str] = None,
+    prim_func_name: Optional[str] = None,
+    evaluator_config: Optional["EvaluatorConfig"] = None,
+    rpc_config: Optional["RPCConfig"] = None,
+    sort_by: Optional[str] = None,
+    desc: Optional[bool] = True,
+):
+    """Benchmark a PrimFunc or IRModule with dynamic input shapes and show 
results.
+
+    Parameters
+    ----------
+    mod_or_func : Union[PrimFunc, IRModule]
+        The PrimFunc or IRModule to be benchmarked.
+    dym_var_sample_func : Callable[[Dict[str, str]], Dict[str, int]]
+        The function to sample dynamic shape variables.
+    dym_var_dict : Optional[Dict[str, str]]
+        Dynamic shape variable dictionary, e.g., {"n": "int32", "m": "int32"}. 
If none, will use
+        the input information from the PrimFunc or IRModule.
+    args : Optional[List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, 
str], ...], str]]]]
+        The input tensor information, including shape and dtype. If none, will 
use
+        the input information from the PrimFunc or IRModule.
+    sample_number : int
+        The number of times to sample dynamic shape variables.
+    target: Optional[Union[str, tvm.target.Target]]
+        The target to be benchmarked on, if none, will get the target from 
context.
+    weight : Optional[int]
+        The weight of this PrimFunc.
+    relax_func_name : Optional[str]
+        The name of the relax function.
+    prim_func_name : Optional[str]
+        The name of the PrimFunc.
+    evaluator_config : Optional["EvaluatorConfig"]
+        The evaluator configuration to use.
+        If none, will use default evaluator configuration.
+    rpc_config : Optional["RPCConfig"]
+        The RPC configuration to connect to the remote device.
+        If none, will use local mode.
+    sort_by : Optional[str]
+        Sort results by this key, if None, no sorting.
+    desc : Optional[bool]
+        Whether to sort results in descending order.
+    """
+    results = []
+    if dym_var_dict is None or args is None:
+        args, dym_var_dict = extract_func_info_from_prim_func(mod_or_func)
+    for _ in range(sample_number):
+        dym_var_sample = dym_var_sample_func(dym_var_dict)
+        _, median, std = benchmark(
+            mod_or_func,
+            args=args,
+            dym_var_sample=dym_var_sample,
+            target=target,
+            evaluator_config=evaluator_config,
+            rpc_config=rpc_config,
+        )
+        row = {
+            "InputInfo": ", ".join([f"{k} = {v}" for k, v in 
dym_var_sample.items()]),
+            "Time(us)": median * 1e6,
+            "Std(us)": std * 1e6,
+        }
+        if relax_func_name is not None:
+            row["RelaxFunc"] = relax_func_name
+        if prim_func_name is not None:
+            row["PrimFunc"] = prim_func_name
+        weight = 1 if weight is None else weight
+        row["Weight"] = weight
+        row["WxTime(ms)"] = weight * median * 1e3
+        results.append(row)
+    print_results(results, sort_by=sort_by, desc=desc)
+
+
+def benchmark_relax_func(
+    mod: tvm.ir.IRModule,
+    relax_func: Union[tvm.ir.GlobalVar, str],
+    sample_number: int = 2,
+    dym_var_sample_func: Callable[
+        [Dict[str, str]],
+        Dict[str, int],
+    ] = default_dym_var_sample_func,
+    target: Union[str, tvm.target.Target] = "llvm -num-cores=4",
+    evaluator_config: Optional["EvaluatorConfig"] = None,
+    rpc_config: Optional["RPCConfig"] = None,
+) -> None:
+    """Benchmark a relax function with dynamic input shapes.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The IRModule to be benchmarked.
+    relax_func : Union[tvm.ir.GlobalVar, str]
+        The relax function to be benchmarked.
+    sample_number : int
+        The number of times to sample dynamic shape variables.
+    dym_var_sample_func : Callable[[Dict[str, str]], Dict[str, int]]
+        The function to sample dynamic shape variables.
+    target : Union[str, tvm.target.Target]
+        The target to be benchmarked on.
+    dev : tvm.runtime.Device
+        The device to be benchmarked on.
+    evaluator_config : Optional["EvaluatorConfig"]
+        The evaluator configuration to use.
+        If none, will use default evaluator configuration.
+    rpc_config : Optional["RPCConfig"]
+        The RPC configuration to connect to the remote device.
+    """
+    # extract function information
+    relax_funcs, dynamic_var_dict = extract_all_func_info_from_relax(mod)
+    # find the relax function global var
+    if isinstance(relax_func, str):
+        for gv in relax_funcs:  # pylint: disable=invalid-name
+            if get_func_name_from_gv(gv) == relax_func:
+                relax_func = gv
+                break
+        if not isinstance(relax_func, tvm.ir.GlobalVar):
+            raise ValueError(
+                f"Cannot find relax function with name {relax_func}, "
+                + f"candidates are: {[get_func_name_from_gv(gv) for gv in 
relax_funcs]}"
+            )
+    # benchmark
+    for _ in range(sample_number):
+        dym_var_sample = dym_var_sample_func(dynamic_var_dict[relax_func])
+        bench_results = []
+        # enumerate all functors
+        for functor in relax_funcs[relax_func]:
+            for args, weight in relax_funcs[relax_func][functor]:
+                _, median, _ = benchmark(
+                    mod[functor],
+                    args=args,
+                    dym_var_sample=dym_var_sample,
+                    target=target,
+                    evaluator_config=evaluator_config,
+                    rpc_config=rpc_config,
+                )
+                bench_results.append(
+                    {
+                        f"PrimFuncs in {get_func_name_from_gv(relax_func)}": 
get_func_name_from_gv(
+                            functor
+                        ),
+                        f"InputInfo({dym_var_sample_str(dym_var_sample)})": ", 
".join(
+                            [str(w) for w in args]
+                        ),
+                        "Time(us)": median * 1e6,
+                        # "Std(us)": std * 1e6,
+                        "Weight": weight,
+                        "WxTime(ms)": median * weight * 1e3,
+                    }
+                )
+        print_results(bench_results)
diff --git a/python/tvm/dlight/benchmark/extract.py 
b/python/tvm/dlight/benchmark/extract.py
new file mode 100644
index 0000000000..df9b06e3fd
--- /dev/null
+++ b/python/tvm/dlight/benchmark/extract.py
@@ -0,0 +1,351 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Performance debug tool for dynamic shape workloads"""
+
+from typing import List, Dict, Union, Tuple, Optional
+from pathlib import Path
+
+import cloudpickle
+
+import tvm
+from tvm import relax
+from .utils import default_dym_var_sample_func, get_func_name_from_gv
+
+# Template of a standalone benchmarking script, filled in by `extract_prim_func`
+# via `str.format`. The generated script passes a PrimFunc named `main` to
+# `benchmark_prim_func`, so the inserted `{func_script}` must define `main`.
+# NOTE: the generated script calls `pickle.loads` on embedded bytes; only run
+# scripts you generated yourself.
+SKETCH = """import pickle
+
+import tvm
+from tvm import relax
+from tvm.script import tir as T
+
+from tvm.dlight.benchmark import benchmark_prim_func
+
+MODEL_NAME = "{model_name}"
+RELAX_FUNC_NAME = "{relax_func_name}"
+PRIM_FUNC_NAME = "{prim_func_name}"
+FUNC_HASH = {func_hash}
+WEIGHT = {weight}
+SAMPLE_NUMBER = {sample_number}
+
+DYM_VAR_SAMPLE_FUNC = {dym_var_sample_func}
+
+# None means extract from PrimFunc
+INPUT_ARGS = {input_args}
+DYM_VAR_DICT = {dym_var_dict}
+
+{func_script}
+
+if __name__ == "__main__":
+    target = tvm.target.Target("{target}")
+    benchmark_prim_func(
+        main,
+        args = INPUT_ARGS,
+        dym_var_dict = DYM_VAR_DICT,
+        dym_var_sample_func = DYM_VAR_SAMPLE_FUNC,
+        sample_number = SAMPLE_NUMBER,
+        target = target,
+        weight = WEIGHT,
+        relax_func_name = RELAX_FUNC_NAME,
+        prim_func_name = PRIM_FUNC_NAME,
+    )
+"""
+
+
+def extract_shape(
+    arg: Union[Tuple, List, relax.Tuple, relax.ShapeStructInfo]
+) -> List[relax.ShapeStructInfo]:
+    """Extract shape information from a relax argument.
+
+    Parameters
+    ----------
+    arg : Union[Tuple, List, relax.Tuple, relax.ShapeStructInfo]
+        The relax argument to be extracted.
+
+    Returns
+    -------
+    result : List[relax.ShapeStructInfo]
+        The extracted shape information.
+    """
+    if isinstance(arg, (tuple, list, tvm.relax.Tuple)):
+        results = []
+        for sub_arg in arg:
+            results.extend(extract_shape(sub_arg))
+        return results
+    return [arg.struct_info]
+
+
+def extract_dynamic_var(
+    func_dict: Dict[
+        tvm.ir.GlobalVar,
+        Dict[
+            tvm.ir.GlobalVar,
+            List[Tuple[List, int]],
+        ],
+    ],
+) -> Dict[tvm.ir.GlobalVar, Dict[str, str]]:
+    """Extract dynamic shape variables from a relax function dictionary.
+
+    Parameters
+    ----------
+    func_dict : Dict[
+        tvm.ir.GlobalVar,
+        Dict[
+            tvm.ir.GlobalVar,
+            List[Tuple[List, int]],
+        ],
+        The relax function dictionary, containing the input arguments' shape 
information of each
+        PrimFunc in a Relax function.
+
+    Returns
+    -------
+    result : Dict[tvm.ir.GlobalVar, Dict[str, str]]
+        The dictionary of dynamic shape variables. Given in format {"n": 
"int32", "m": "int32"}.
+    """
+    dym_var_dict: Dict[tvm.ir.GlobalVar, Dict[str, str]] = {}
+    for gv in func_dict:  # pylint: disable=invalid-name,too-many-nested-blocks
+        dym_var_dict[gv] = {}
+        for functor in func_dict[gv]:
+            for arg_list, _ in func_dict[gv][functor]:
+                for arg in arg_list:
+                    if isinstance(arg, relax.TensorStructInfo):
+                        for val in arg.shape.values:
+                            if isinstance(val, tvm.tir.Var):
+                                dym_var_dict[gv][str(val)] = val.dtype
+                    elif isinstance(arg, relax.ShapeStructInfo):
+                        for val in arg.values:
+                            if isinstance(val, tvm.tir.Var):
+                                dym_var_dict[gv][str(val)] = val.dtype
+                    else:
+                        raise NotImplementedError
+    return dym_var_dict
+
+
+def update_records(
+    records: Dict[List[relax.ShapeStructInfo], int], new_args: 
List[relax.ShapeStructInfo]
+) -> None:
+    """Update the count of a function input argument config.
+
+    Parameters
+    ----------
+    records : Dict[List[relax.ShapeStructInfo], int]
+        The dictionary to count how many times a function input argument 
config appears.
+    new_args : List[relax.ShapeStructInfo]
+        The new input argument config.
+    """
+    for i, (args, count) in enumerate(records):
+        if new_args == args:
+            records[i] = (args, count + 1)
+            return
+    records.append((new_args, 1))
+
+
+def extract_func_info_from_prim_func(
+    func: tvm.tir.PrimFunc,
+) -> Tuple[List[Tuple[Tuple[Union[tvm.tir.Var, int], ...], str]], Dict[str, 
str]]:
+    """Extract function input information from a PrimFunc.
+
+    Parameters
+    ----------
+    func : tvm.tir.PrimFunc
+        The PrimFunc to be analyzed.
+
+    Returns
+    -------
+    result : Tuple[
+        List[Tuple[Tuple[Union[tvm.tir.Var, int], ...], str]],
+        Dict[str, str],
+    ]
+        The function input information and dynamic shape variable dictionary.
+    """
+    func_args = []
+    dym_var = {}
+    for param in func.params:
+        buffer = func.buffer_map[param]
+        shape = []
+        for dim in buffer.shape:
+            if isinstance(dim, tvm.tir.IntImm):
+                shape.append(dim.value)
+            elif isinstance(dim, tvm.tir.Var):
+                dym_var[str(dim)] = str(dim.dtype)
+                shape.append(dim)
+            else:
+                raise ValueError(f"Unknown shape: {buffer.shape}")
+        func_args.append((tuple(shape), str(buffer.dtype)))
+    return func_args, dym_var
+
+
+def extract_all_func_info_from_relax(
+    mod: tvm.ir.IRModule,
+) -> Tuple[
+    Dict[tvm.ir.GlobalVar, Dict[tvm.ir.GlobalVar, List[Tuple[List, int]]]],
+    Dict[tvm.ir.GlobalVar, Dict[str, str]],
+]:
+    """Extract function input information from a relax module.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The Relax module to be analyzed.
+
+    Returns
+    -------
+    result : Tuple[
+        Dict[tvm.ir.GlobalVar, Dict[tvm.ir.GlobalVar, List[Tuple[List, int]]]],
+        Dict[tvm.ir.GlobalVar, Dict[str, str]],
+    ]
+        The function input information and dynamic shape variable dictionary.
+    """
+    relax_func_dict: Dict[tvm.ir.GlobalVar, Dict[tvm.ir.GlobalVar, 
List[Tuple[List, int]]]] = {}
+    for gv, func in mod.functions.items():  # pylint: 
disable=invalid-name,too-many-nested-blocks
+        if isinstance(func, tvm.relax.Function):
+            for block in func.body.blocks:
+                for binding in block.bindings:
+                    if isinstance(binding.value, tvm.relax.expr.Call):
+                        raw_args = binding.value.args
+                        functor = raw_args[0]
+                        if isinstance(functor, tvm.ir.GlobalVar) and 
isinstance(
+                            mod.functions[functor], tvm.tir.PrimFunc
+                        ):
+                            args = extract_shape(raw_args[1:]) + 
extract_shape(binding.value)
+                            if isinstance(functor, tvm.ir.GlobalVar):
+                                if not gv in relax_func_dict:
+                                    relax_func_dict[gv] = {}
+                                if not functor in relax_func_dict[gv]:
+                                    relax_func_dict[gv][functor] = []
+                                update_records(relax_func_dict[gv][functor], 
args)
+
+    return relax_func_dict, extract_dynamic_var(relax_func_dict)
+
+
+def extract_prim_func(  # pylint: disable=too-many-arguments
+    model_name: str,
+    relax_func_name: str,
+    prim_func_name: str,
+    func: tvm.tir.PrimFunc,
+    *,
+    func_args: Optional[List[Tuple[Tuple[Union[tvm.relax.expr.Call, int], 
...], str]]] = None,
+    dym_var_dict: Optional[Dict[str, str]] = None,
+    weight: int = 1,
+    sample_number: int = 5,
+    target: Optional[Union[str, tvm.target.Target]] = None,
+) -> str:
+    """Extract a self-contained PrimFunc test file from a Relax module.
+
+    Parameters
+    ----------
+    model_name: str
+        The name of the model.
+    relax_func_name: str
+        The name of the Relax function.
+    prim_func_name: str
+        The name of the prim function.
+    func: tvm.tir.PrimFunc
+        The PrimFunc to be extracted.
+    func_args: Optional[List[Tuple[Tuple[Union[tvm.relax.expr.Call, int], 
...], str]]]
+        The arguments of the prim function, including both static and dynamic 
shape arguments.
+        Given in format [ ..., ((1, n, 128), "float32"), ... ].
+        If not given, the arguments will be extracted from the PrimFunc.
+    dym_var_dict: Optional[Dict[str, str]]
+        The dictionary of dynamic shape variables. Given in format {"n": 
"int32", "m": "int32"}.
+        If not given, the dictionary will be extracted from the PrimFunc.
+    weight: int
+        The weight of the prim function, by default 1.
+    sample_number: int
+        The number of times to sample dynamic shape variables, by default 5.
+    target: Optional[Union[str, tvm.target.Target]]
+        The target device to run the PrimFunc. If None, will use target from 
the context.
+
+    Returns
+    -------
+    result : str
+        The extracted PrimFunc test file content.
+    """
+    if target is None:
+        target = tvm.target.Target.current()
+        target_str = str(target)
+        if target is None:
+            raise ValueError("Target is not specified.")
+    elif isinstance(target, str):
+        target_str = target
+        target = tvm.target.Target(target)
+    elif isinstance(target, tvm.target.Target):
+        target_str = str(target)
+    else:
+        raise TypeError("Unsupported target type: " + str(type(target)))
+
+    return SKETCH.format(
+        **{
+            "model_name": model_name,
+            "relax_func_name": relax_func_name,
+            "prim_func_name": prim_func_name,
+            "func_hash": tvm.ir.structural_hash(func),
+            "weight": weight,
+            "sample_number": sample_number,
+            "dym_var_dict": f"pickle.loads({cloudpickle.dumps(dym_var_dict)})"
+            if dym_var_dict is not None
+            else "None",
+            "input_args": f"pickle.loads({cloudpickle.dumps(func_args)})" if 
func_args else "None",
+            "dym_var_sample_func": "pickle.loads("
+            + f"{cloudpickle.dumps(default_dym_var_sample_func)}"
+            + ")",
+            "func_script": func.script(),
+            "target": target_str,
+        }
+    )
+
+
+def extract_from_relax(
+    mod: tvm.ir.IRModule,
+    model_name: str,
+    file_path: str,
+    target: Optional[Union[str, tvm.target.Target]] = None,
+) -> None:
+    """Extract self-contained PrimFunc test files from a Relax module.
+
+    Parameters
+    ----------
+    mod: tvm.ir.IRModule
+        The Relax module to be extracted.
+    model_name: str
+        The name of the model.
+    file_path: str
+        The path to store the extracted files.
+    target: Optional[Union[str, tvm.target.Target]]
+        The target device to run the PrimFunc. If None, will use target from 
the context.
+    """
+    relax_funcs, dym_var_dict = extract_all_func_info_from_relax(mod)
+    Path(file_path).mkdir(parents=True, exist_ok=True)
+    for relax_func_gv in relax_funcs:  # pylint: 
disable=consider-using-dict-items
+        relax_func_name = get_func_name_from_gv(relax_func_gv)
+        for prim_func_gv in relax_funcs[relax_func_gv]:
+            prim_func_name = get_func_name_from_gv(prim_func_gv)
+            for func_args, weight in relax_funcs[relax_func_gv][prim_func_gv]:
+                with open(
+                    f"{file_path}/{relax_func_name}_{prim_func_name}.py", "w", 
encoding="utf-8"
+                ) as file:
+                    print(
+                        extract_prim_func(
+                            model_name=model_name,
+                            relax_func_name=relax_func_name,
+                            prim_func_name=prim_func_name,
+                            func=mod[prim_func_gv],
+                            dym_var_dict=dym_var_dict[relax_func_gv],
+                            func_args=func_args,
+                            weight=weight,
+                            target=target,
+                        ),
+                        file=file,
+                    )
diff --git a/python/tvm/dlight/benchmark/utils.py 
b/python/tvm/dlight/benchmark/utils.py
new file mode 100644
index 0000000000..72e0ac8de0
--- /dev/null
+++ b/python/tvm/dlight/benchmark/utils.py
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Util functions for benchmarking dynamic shape workloads"""
+
+from typing import Dict, List, Tuple, Union, Any
+
+import tvm
+from tvm import relax
+
# Concrete input description: a list of (shape tuple of ints, dtype string)
# pairs, e.g. [((1, 64, 128), "int32")]. populuate_input_shape also uses this
# form for shape-valued inputs, with the dtype slot set to "scalar".
INPUT_SHAPE_TYPE = List[Tuple[Tuple[int, ...], str]]  # pylint: disable=invalid-name
+
+
def get_func_name_from_gv(gv: tvm.ir.GlobalVar) -> str:  # pylint: disable=invalid-name
    """Get function name from a global variable.

    Parameters
    ----------
    gv : tvm.ir.GlobalVar
        The given global variable.

    Returns
    -------
    result : str
        The function name, taken from ``gv.name_hint``. Objects without a
        string ``name_hint`` fall back to their printed text (``astext()``)
        with the "...@" prefix removed, or the full text when it has no "@".
    """
    # name_hint is the GlobalVar's own name; reading it avoids running the
    # printer (previously up to three times) and depending on its text format.
    name_hint = getattr(gv, "name_hint", None)
    if isinstance(name_hint, str):
        return str(name_hint)
    text = gv.astext()
    return text.split("@")[1] if "@" in text else text
+
+
def dym_var_sample_str(sample: Dict[Union[str, tvm.relax.expr.Call], int]) -> str:
    """Render a dynamic-variable sample as comma-separated ``name=value`` pairs.

    Parameters
    ----------
    sample : Dict[Union[str, tvm.relax.expr.Call], int]
        Variable value sample, e.g., {n: 64, m: 128} or {"n": 64, "m": 128}.
        Keys and values are formatted with ``str()``.

    Returns
    -------
    result : str
        The pairs in the dictionary's iteration order, e.g., "n=64, m=128".
    """
    rendered = (f"{var}={value}" for var, value in sample.items())
    return ", ".join(rendered)
+
+
def populuate_input_shape(
    input_infos: List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, str], ...], str]]],
    dym_var_sample: Dict[str, int],
) -> INPUT_SHAPE_TYPE:
    """
    Populate input shapes with dynamic shape variable samples.

    Parameters
    ----------
    input_infos : List[Union[relax.TensorStructInfo, Tuple[Tuple[Union[int, str], ...], str]]]
        Input tensor information, including shape and dtype,
        e.g., [..., Shape(1, n, 128) with dtype="int32", ...].
        A ``relax.ShapeStructInfo`` entry denotes a shape-valued ("scalar")
        input; its first value is used.
    dym_var_sample : Dict[str, int]
        Dynamic shape variable sample, e.g., {"n": 64}

    Returns
    -------
    results : INPUT_SHAPE_TYPE
        Input shapes with dynamic shape variable samples, e.g.,
        [..., ((1, 64, 128), "int32"), ...] if n=64 or
        [..., ((128,), "scalar"), ...] if n=128 for scalar input

    Raises
    ------
    KeyError
        If a symbolic dimension has no entry in ``dym_var_sample``.
    """

    def _concrete(dim: Any) -> int:
        # Static dims may be Python ints or tir.IntImm; anything else (a tir.Var
        # or a plain string such as "n") is looked up by its printed name.
        if isinstance(dim, int):
            return dim
        if isinstance(dim, tvm.tir.IntImm):
            return dim.value
        return dym_var_sample[str(dim)]

    results: INPUT_SHAPE_TYPE = []
    for input_info in input_infos:
        if isinstance(input_info, relax.struct_info.ShapeStructInfo):
            # Shape-valued input: a one-element shape holding the sampled value.
            results.append(((_concrete(input_info.values[0]),), "scalar"))
            continue
        if isinstance(input_info, relax.TensorStructInfo):
            tensor_shape, tensor_dtype = input_info.shape, input_info.dtype
        else:
            tensor_shape, tensor_dtype = input_info  # type: ignore
        results.append((tuple(_concrete(dim) for dim in tensor_shape), tensor_dtype))
    return results
+
+
def default_dym_var_sample_func(
    dym_var_dict: Dict[str, str], *, low: int = 2, high: int = 128
) -> Dict[str, int]:
    """
    Default dynamic shape variable sample function.
    Sample a random value for each dynamic shape variable.

    Parameters
    ----------
    dym_var_dict : Dict[str, str]
        Dynamic shape variable dictionary, e.g., {"n": "int32", "m": "int32"}
    low : int
        Inclusive lower bound of every sampled value (default 2).
    high : int
        Inclusive upper bound of every sampled value (default 128).

    Returns
    -------
    result : Dict[str, int]
        Dynamic shape variable sample, e.g., {"n": 64, "m": 128}

    Raises
    ------
    ValueError
        If ``low`` is greater than ``high``.
    TypeError
        If a variable's dtype is neither "int32" nor "int64".
    """
    import random  # pylint: disable=import-outside-toplevel

    if low > high:
        raise ValueError(f"Invalid sample range: low={low} > high={high}")
    results = {}
    for var, dtype in dym_var_dict.items():
        if dtype not in ("int32", "int64"):
            # f-string so non-str dtype objects still yield this message instead
            # of a TypeError from string concatenation.
            raise TypeError(f"Unsupported dynamic shape variable type: {dtype}")
        results[var] = random.randint(low, high)
    return results
+
+
def print_results(
    bench_results: List[Dict[str, Any]],
    sort_by: Union[str, None] = "WxTime(ms)",
    desc: bool = True,
):
    """Print benchmark results.

    Results are printed as a pandas DataFrame when pandas is installed, and
    otherwise as a tab-separated table whose columns are the keys of the first
    record. Both paths honour ``sort_by`` and ``desc``.

    Parameters
    ----------
    bench_results : List[Dict[str, Any]]
        Benchmark results as dictionary list, one dictionary per row.
    sort_by : Optional[str]
        Sort results by this key, if None, no sorting. Ignored when
        ``bench_results`` is empty.
    desc : bool
        Whether to sort results in descending order.

    Raises
    ------
    ValueError
        If ``sort_by`` is given, ``bench_results`` is non-empty, and ``sort_by``
        is not a result column. With pandas, that means no record has the key.
        In the raw fallback, it means the first record lacks it.
    """
    # pylint: disable=invalid-name, import-outside-toplevel
    # An empty result list has no columns to sort by; print it as-is rather
    # than failing the "sort_by not in columns" check.
    need_sort = sort_by is not None and len(bench_results) > 0
    try:
        import pandas as pd
    except ModuleNotFoundError:
        pd = None

    if pd is not None:
        # Build the frame in one call: growing it with pd.concat per record is
        # quadratic. Keys missing from a record become NaN.
        df = pd.DataFrame(list(bench_results))
        if need_sort:
            if sort_by not in df.columns:
                raise ValueError(f"sort_by key {sort_by} not in benchmark results")
            # drop=True discards the old row labels instead of materialising them
            # as an "index" column, which would clobber a real "index" result key.
            df = df.sort_values(sort_by, ascending=not desc).reset_index(drop=True)
        print(df)
    else:
        print("Pandas not found, printing results in raw format.")
        keys = [str(key) for key in bench_results[0]] if bench_results else []
        records = bench_results
        if need_sort:
            if sort_by not in keys:
                raise ValueError(f"sort_by key {sort_by} not in benchmark results")
            records = sorted(bench_results, key=lambda record: record[sort_by], reverse=desc)
        print("\t".join(keys))
        for record in records:
            print("\t".join(str(record[key]) for key in keys))
    print("\n")
    # pylint: enable=invalid-name, import-outside-toplevel
diff --git a/python/tvm/meta_schedule/testing/tune_utils.py 
b/python/tvm/meta_schedule/testing/tune_utils.py
index 17064c64ab..8fe6f22d11 100644
--- a/python/tvm/meta_schedule/testing/tune_utils.py
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -18,7 +18,6 @@
 from typing import Callable, Optional, Union, List, Dict
 from statistics import median
 import json
-import warnings
 import numpy as np  # type: ignore
 
 import tvm
@@ -48,10 +47,6 @@ def generate_input_data(
     """
     if input_dtype.startswith("float"):
         return np.random.uniform(size=input_shape).astype(input_dtype)
-    if low is None or high is None:
-        warnings.warn(
-            f"Model input value range for shape {input_shape} of {input_dtype} 
is not set!"
-        )
     range_map = {
         "uint8": (0, 255),
         "int8": (-128, 127),
diff --git a/tests/python/dlight/test_benchmark.py 
b/tests/python/dlight/test_benchmark.py
new file mode 100644
index 0000000000..3153be2cc9
--- /dev/null
+++ b/tests/python/dlight/test_benchmark.py
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+
+import tempfile
+import pytest
+
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.testing.local_rpc import LocalRPC
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.script import relax as R
+
+
+from tvm.dlight.benchmark import (
+    benchmark,
+    benchmark_prim_func,
+    benchmark_relax_func,
+    extract_prim_func,
+    extract_from_relax,
+    extract_func_info_from_prim_func,
+)
+import tvm.testing
+
+# pylint: 
disable=no-self-argument,invalid-name,line-too-long,no-method-argument
+# fmt: off
# Test IRModule with a single dynamic dimension `n`: three float16 TIR
# PrimFuncs (two constant fills and a matmul reducing over `n`) plus a Relax
# function `test` that calls them via R.call_tir.
@I.ir_module
class Module:
    # Fills a (1, 32, 1, n) float16 buffer with 1.0.
    @T.prim_func
    def full1(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(1.0)

    # Fills a (1, 32, n, 128) float16 buffer with 1.0.
    @T.prim_func
    def full2(var_T_full: T.handle):
        T.func_attr({"op_pattern": 0, "tir.noalias": T.bool(True)})
        n = T.int64()
        T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, T.int64(128)):
            with T.block("T_full"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads()
                T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
                T_full[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(1.0)

    # matmul[0, i, 0, j] = sum_k A[0, i, 0, k] * B[0, i, k, j], with k running
    # over the dynamic dimension n: (1, 32, 1, n) x (1, 32, n, 128) -> (1, 32, 1, 128).
    @T.prim_func
    def matmul1(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16")
        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16")
        # with T.block("root"):
        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
                with T.init():
                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]

    # Relax entry point; the "tir_var_upper_bound" attribute declares 2048 as
    # the upper bound of n.
    # NOTE(review): lv1_1, lv1_2 and lv2_1 are never used; presumably they are
    # here so the benchmark extractor sees repeated calls (full1 x3, full2 x2)
    # and derives call weights from them -- confirm against the extractor.
    @R.function
    def test():
        n = T.int64()
        R.func_attr({"tir_var_upper_bound": {"n": 2048}})
        cls = Module
        with R.dataflow():
            lv1 = R.call_tir(cls.full1,(), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1_1 = R.call_tir(cls.full1,(), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv1_2 = R.call_tir(cls.full1,(), out_sinfo=R.Tensor((1, 32, 1, n), dtype="float16"))
            lv2 = R.call_tir(cls.full2,(), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv2_1 = R.call_tir(cls.full2,(), out_sinfo=R.Tensor((1, 32, n, 128), dtype="float16"))
            lv3 = R.call_tir(cls.matmul1, (lv1, lv2), out_sinfo=R.Tensor((1, 32, 1, 128), dtype="float16"))
            R.output(lv3)
        return lv3
+
# Pre-scheduled ("tir.is_scheduled": 1) CUDA PrimFunc computing
#   matmul[0, i, j] = sum_k inp0[0, i, k] * inp1[k, j]
# for a dynamic row count m: (1, m, 4096) x (4096, 4096) -> (1, m, 4096), all
# float32 (the match_buffer default dtype). Rows are padded up to a multiple of
# 32 ((m + 31) // 32 * 32): padded rows of inp0 are staged as 0 and only rows
# with v1 < m are written back to `matmul`.
@T.prim_func
def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle):
    T.func_attr({"tir.is_scheduled": 1})
    m = T.int64()
    inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
    matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
    # with T.block("root"):
    matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
    inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
    inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
    for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
        for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
            for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                                # Zero-initialise this thread's 4x4 tile of the local accumulator.
                                for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
                                    with T.block("matmul_init"):
                                        v0 = T.axis.spatial(T.int64(1), ax0)
                                        v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
                                        v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
                                        T.reads()
                                        T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
                                        matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
                                # Reduction over k in 256 slices of 16.
                                for ax3_0 in range(T.int64(256)):
                                    # Stage the inp0 slice into shared scope; rows >= m read as 0.
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
                                                    with T.block("inp0_reindex_pad_shared"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                        T.reads(inp0[v0, v1, v2])
                                                        T.writes(inp0_reindex_pad_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
                                    # Stage the inp1 slice transposed: inp1_reindex_shared[0, j, k] = inp1[k, j].
                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
                                                    with T.block("inp1_reindex_shared"):
                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                                        v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
                                                        T.reads(inp1[v2, v1])
                                                        T.writes(inp1_reindex_shared[v0, v1, v2])
                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
                                                        inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
                                    # Accumulate the 16-wide k slice into the local 4x4 tile.
                                    for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
                                        with T.block("matmul_update"):
                                            v0 = T.axis.spatial(T.int64(1), ax0)
                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
                                            v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
                                            T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
                                            T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
                                            matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
                                # Write the tile back, skipping padded rows (v1 >= m).
                                for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
                                    for ax2_1_1 in T.vectorized(T.int64(2)):
                                        with T.block("matmul_reindex_pad_local"):
                                            v0 = T.axis.spatial(T.int64(1), ax0_1)
                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
                                            v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
                                            T.reads(matmul_reindex_pad_local[v0, v1, v2])
                                            T.writes(matmul[T.int64(0), v1, v2])
                                            if v1 < m:
                                                matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+# fmt: on
+# pylint: enable=no-self-argument,invalid-name,line-too-long,no-method-argument
+
+
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_rpc():
    """Benchmark cuda_workload through a local RPC tracker with m pinned to 128."""
    symbolic_args = [
        ((1, "m", 4096), "float32"),
        ((4096, 4096), "float32"),
        ((1, "m", 4096), "float32"),
    ]
    expected_inputs = [
        ((1, 128, 4096), "float32"),
        ((4096, 4096), "float32"),
        ((1, 128, 4096), "float32"),
    ]
    with LocalRPC() as rpc:
        tracker_config = ms.runner.RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        input_infos, _, _ = benchmark(
            cuda_workload,
            args=symbolic_args,
            dym_var_sample={"m": 128},
            target="nvidia/geforce-rtx-3070",
            rpc_config=tracker_config,
        )
    # Every "m" must have been replaced by the sampled value.
    assert input_infos == expected_inputs
+
+
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_local():
    """Benchmark cuda_workload on the local device with m pinned to 128."""
    symbolic_args = [
        ((1, "m", 4096), "float32"),
        ((4096, 4096), "float32"),
        ((1, "m", 4096), "float32"),
    ]
    expected_inputs = [
        ((1, 128, 4096), "float32"),
        ((4096, 4096), "float32"),
        ((1, 128, 4096), "float32"),
    ]
    input_infos, _, _ = benchmark(
        cuda_workload,
        args=symbolic_args,
        dym_var_sample={"m": 128},
        target="nvidia/geforce-rtx-3070",
    )
    assert input_infos == expected_inputs
+
+
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_full_local():
    """Run the full PrimFunc benchmark with the target taken from the context."""
    gpu_target = tvm.target.Target("nvidia/geforce-rtx-3070")
    with gpu_target:
        benchmark_prim_func(cuda_workload)
+
+
@pytest.mark.skip("requires CUDA")
def test_benchmark_prim_func_full_rpc():
    """Run the full PrimFunc benchmark over local RPC with a custom evaluator."""
    evaluator = ms.runner.EvaluatorConfig(
        number=10,
        repeat=10,
        min_repeat_ms=0,
        enable_cpu_cache_flush=False,
    )
    with LocalRPC() as rpc:
        tracker_config = ms.runner.RPCConfig(
            tracker_host=rpc.tracker_host,
            tracker_port=rpc.tracker_port,
            tracker_key=rpc.tracker_key,
            session_priority=1,
            session_timeout_sec=100,
        )
        benchmark_prim_func(
            cuda_workload,
            target="nvidia/geforce-rtx-3070",
            rpc_config=tracker_config,
            evaluator_config=evaluator,
        )
+
+
def test_benchmark_relax_func():
    """Benchmark the Relax function Module["test"] on a 4-core LLVM target."""
    cpu_target = tvm.target.Target("llvm -num-cores=4")
    with cpu_target:
        benchmark_relax_func(Module, "test")
+
+
def test_extract_prim_func_full1():
    """Extract a benchmark file for full1 using explicit args and dynamic vars."""
    code = extract_prim_func(
        model_name="TEST",
        relax_func_name="test",
        prim_func_name="full1",
        func=Module["full1"],  # type: ignore
        func_args=[((1, 32, 1, "n"), "float16")],
        dym_var_dict={"n": "int32"},
        weight=2,
        sample_number=10,
        target="llvm -num-cores=4",
    )
    # extract_from_relax writes this text out as a standalone .py file, so it
    # must at least be a non-empty, syntactically valid Python module.
    assert isinstance(code, str) and code
    compile(code, "TEST_test_full1.py", "exec")
+
+
def test_extract_prim_func_matmul1():
    """Extract a benchmark file for matmul1, inferring args from the PrimFunc."""
    code = extract_prim_func(
        model_name="TEST",
        relax_func_name="test",
        prim_func_name="matmul1",
        func=Module["matmul1"],  # type: ignore
        weight=2,
        sample_number=10,
        target="llvm -num-cores=4",
    )
    # The generated benchmark file must be non-empty, valid Python source.
    assert isinstance(code, str) and code
    compile(code, "TEST_test_matmul1.py", "exec")
+
+
def test_extract_from_relax():
    """Extracting from Module must write benchmark .py files to the target dir."""
    import os  # pylint: disable=import-outside-toplevel

    with tvm.target.Target("llvm -num-cores=4"):
        with tempfile.TemporaryDirectory() as filepath:
            extract_from_relax(
                Module,
                "TEST",
                file_path=filepath,
            )
            # Module["test"] calls PrimFuncs, so at least one
            # "<relax_func>_<prim_func>.py" file must have been produced.
            assert any(name.endswith(".py") for name in os.listdir(filepath))
+
+
def test_extract_func_info_from_prim_func():
    """Check the (input info, dynamic var dtype) pair extracted from each PrimFunc."""
    cases = [
        (
            cuda_workload,
            "([((1, m, 4096), 'float32'), ((4096, 4096), 'float32'), "
            "((1, m, 4096), 'float32')], {'m': 'int64'})",
        ),
        (Module["full1"], "([((1, 32, 1, n), 'float16')], {'n': 'int64'})"),
        (
            Module["matmul1"],
            "([((1, 32, 1, n), 'float16'), ((1, 32, n, 128), 'float16'), "
            "((1, 32, 1, 128), 'float16')], {'n': 'int64'})",
        ),
        (Module["full2"], "([((1, 32, n, 128), 'float16')], {'n': 'int64'})"),
    ]
    for prim_func, want in cases:
        assert str(extract_func_info_from_prim_func(prim_func)) == want
+
+
if __name__ == "__main__":
    # Delegate to tvm.testing.main() so the file can be run directly as a script.
    tvm.testing.main()

Reply via email to