zxybazh commented on code in PR #11797:
URL: https://github.com/apache/tvm/pull/11797#discussion_r910464495
##########
python/tvm/meta_schedule/testing/utils.py:
##########
@@ -77,3 +81,145 @@ def apply_fixed_schedules(
database.commit_tuning_record(tune_rec)
return database
+
+
+def generate_input_data(input_shape: List[int], input_dtype: str) -> np.ndarray:
+    """Generate input data with the given shape and data type.
+
+    Parameters
+    ----------
+    input_shape : List[int]
+        The shape of the input data.
+    input_dtype : str
+        The data type of the input data.
+
+    Returns
+    -------
+    input_data : np.ndarray
+        The generated input data with the given shape and data type, as a numpy ndarray.
+    """
+    if input_dtype.startswith("float"):
+        return np.random.uniform(size=input_shape).astype(input_dtype)
+    if input_dtype in ["uint8", "int8"]:
+        return np.random.randint(
+            low=0,
+            high=127,
+            size=input_shape,
+            dtype="int32",  # TODO(zxybazh): fix the datatype when int8 / uint8 is supported better
+        )
+    if input_dtype in ["int32", "int64"]:
+        return np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
Review Comment:
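That's right: for `gpt-2` and `bert` the embedding tables are not the same, so no
single default range for integer inputs fits every model. In that case the range
has to come from the user, so I'll add `low` and `high` as arguments to this
function and issue a warning when they are not provided.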
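A minimal sketch of the change proposed in the comment, assuming `low` and `high` become optional keyword arguments (hypothetical names and defaults; the signature that actually lands in TVM may differ). When they are omitted for `int32`/`int64` inputs, it falls back to the diff's `[0, 10000)` range and emits a `UserWarning` via `warnings.warn`. The `ValueError` for unsupported dtypes is also an assumption, since the quoted hunk ends before the function does.

```python
import warnings
from typing import List, Optional

import numpy as np


def generate_input_data(
    input_shape: List[int],
    input_dtype: str,
    *,
    low: Optional[int] = None,   # hypothetical: inclusive lower bound for integer inputs
    high: Optional[int] = None,  # hypothetical: exclusive upper bound for integer inputs
) -> np.ndarray:
    """Generate random input data with the given shape and data type (sketch)."""
    if input_dtype.startswith("float"):
        return np.random.uniform(size=input_shape).astype(input_dtype)
    if input_dtype in ["uint8", "int8"]:
        # Same placeholder behavior as the diff: int32 values in [0, 127).
        return np.random.randint(low=0, high=127, size=input_shape, dtype="int32")
    if input_dtype in ["int32", "int64"]:
        if low is None or high is None:
            # Integer inputs are often indices into an embedding table, whose
            # valid range is model specific, so warn instead of guessing silently.
            warnings.warn(
                "low/high not given for integer input; defaulting to [0, 10000), "
                "which may not match this model's embedding table.",
                stacklevel=2,
            )
            low = 0 if low is None else low
            high = 10000 if high is None else high
        # np.random.randint samples from [low, high), i.e. high is exclusive.
        return np.random.randint(low=low, high=high, size=input_shape, dtype=input_dtype)
    # Assumption: unsupported dtypes are rejected explicitly.
    raise ValueError(f"Unsupported input dtype: {input_dtype}")
```

For a BERT-style model, a caller would pass the vocabulary size as `high`, so the generated token IDs stay inside the embedding table and no warning is raised.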
--
This is an automated message from the Apache Git Service.