junrushao1994 commented on code in PR #11797:
URL: https://github.com/apache/tvm/pull/11797#discussion_r910457244
##########
python/tvm/meta_schedule/testing/utils.py:
##########
@@ -77,3 +81,145 @@ def apply_fixed_schedules(
database.commit_tuning_record(tune_rec)
return database
+
+
+def generate_input_data(input_shape: List[int], input_dtype: str) ->
np.ndarray:
+ """Generate input date with given shape and data type.
+
+ Parameters
+ ----------
+ input_shape : List[int]
+ The shape of the input data.
+ input_dtype : str
+ The data type of the input date.
+
+ Returns
+ -------
+ input_data : np.ndarray
+ The generated input data with given shape and data type in numpy
ndarray.
+ """
+ if input_dtype.startswith("float"):
+ return np.random.uniform(size=input_shape).astype(input_dtype)
+ if input_dtype in ["uint8", "int8"]:
+ return np.random.randint(
+ low=0,
+ high=127,
+ size=input_shape,
+ dtype="int32", # TODO(zxybazh): fix the datatype when int8 /
uint8 is supported better
+ )
+ if input_dtype in ["int32", "int64"]:
+ return np.random.randint(low=0, high=10000, size=input_shape,
dtype=input_dtype)
+ raise ValueError("Unsupported input datatype!")
+
+
+def f_timer(backend: str) -> Callable:
+ """Create a function to run and benchmark the performance of whole given
runtime module,
+ or Executable in relay vm.
+
+ Parameters
+ ----------
+ backend : str
+ The backend to use, graph / vm.
+
+ Returns
+ -------
+ func : Callable
+ The function to benchmark the workload.
+ """
+
+ def f_timer_func(
+ rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable],
+ dev: tvm.device,
+ input_data: Dict[str, NDArray],
+ ) -> None:
+ """Run and benchmark the given runtime module, print out the result.
+
+ Parameters
+ ----------
+ rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
+ The runtime module or vm executable.
+ dev : tvm.device
+ The device type to run workload.
+ input_data : Dict[str, np.ndarray]
+ The input data as a dictionary.
+ """
+ from tvm.contrib.graph_executor import GraphModule #
pylint:disable=import-outside-toplevel
+ from tvm.runtime.vm import VirtualMachine #
pylint:disable=import-outside-toplevel
+
+ if backend == "vm":
+ vm = VirtualMachine(rt_mod, dev) # pylint: disable=invalid-name
+ ftimer = vm.benchmark(
+ dev, min_repeat_ms=500, repeat=5, number=1, end_to_end=False,
**input_data
+ )
+ elif backend == "graph":
+ mod = GraphModule(rt_mod["default"](dev))
+ for input_name, input_value in input_data.items():
+ mod.set_input(input_name, input_value)
+ ftimer = mod.module.time_evaluator("run", dev, min_repeat_ms=500,
repeat=5, number=1)()
+ else:
+ raise ValueError(f"Backend {backend} not supported in f_timer!")
+
+ results = list(np.array(ftimer.results) * 1000.0) # type: ignore
+
+ print("Running time in time_evaluator: ", results)
+ print("-------------------------------")
+ print(f" Min (ms) : {min(results)}")
+ print(f" Max (ms) : {max(results)}")
+ print(f" Median (ms) : {median(results)}")
+ print(f"Average (ms) : {sum(results) / len(results)}")
+
+ return f_timer_func
+
+
+def f_per_layer(graph: str) -> Callable:
Review Comment:
Ditto — let's find a better name here too. How about `create_per_layer_timer_func`?
##########
python/tvm/meta_schedule/testing/utils.py:
##########
@@ -77,3 +81,145 @@ def apply_fixed_schedules(
database.commit_tuning_record(tune_rec)
return database
+
+
+def generate_input_data(input_shape: List[int], input_dtype: str) ->
np.ndarray:
+ """Generate input date with given shape and data type.
+
+ Parameters
+ ----------
+ input_shape : List[int]
+ The shape of the input data.
+ input_dtype : str
+ The data type of the input date.
+
+ Returns
+ -------
+ input_data : np.ndarray
+ The generated input data with given shape and data type in numpy
ndarray.
+ """
+ if input_dtype.startswith("float"):
+ return np.random.uniform(size=input_shape).astype(input_dtype)
+ if input_dtype in ["uint8", "int8"]:
+ return np.random.randint(
+ low=0,
+ high=127,
+ size=input_shape,
+ dtype="int32", # TODO(zxybazh): fix the datatype when int8 /
uint8 is supported better
+ )
+ if input_dtype in ["int32", "int64"]:
+ return np.random.randint(low=0, high=10000, size=input_shape,
dtype=input_dtype)
+ raise ValueError("Unsupported input datatype!")
+
+
+def f_timer(backend: str) -> Callable:
Review Comment:
Let's find a better name. How about `create_timer`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]