junrushao1994 commented on code in PR #11797:
URL: https://github.com/apache/tvm/pull/11797#discussion_r910456672
##########
python/tvm/meta_schedule/testing/utils.py:
##########
@@ -77,3 +81,145 @@ def apply_fixed_schedules(
database.commit_tuning_record(tune_rec)
return database
+
+
+def generate_input_data(input_shape: List[int], input_dtype: str) -> np.ndarray:
+ """Generate input data with given shape and data type.
+
+ Parameters
+ ----------
+ input_shape : List[int]
+ The shape of the input data.
+ input_dtype : str
+ The data type of the input data.
+
+ Returns
+ -------
+ input_data : np.ndarray
+ The generated input data with given shape and data type in numpy ndarray.
+ """
+ if input_dtype.startswith("float"):
+ return np.random.uniform(size=input_shape).astype(input_dtype)
+ if input_dtype in ["uint8", "int8"]:
+ return np.random.randint(
+ low=0,
+ high=127,
+ size=input_shape,
+ dtype="int32", # TODO(zxybazh): fix the datatype when int8 / uint8 is supported better
+ )
+ if input_dtype in ["int32", "int64"]:
+ return np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
Review Comment:
If this is an index table used in embedding lookups, then `low` and `high`
here probably correspond to the size of the embedding table. If that is not
taken care of, out-of-range indices could cause out-of-bounds buffer accesses.
I don't have a good idea of how to handle this properly, but let's add a
warning here.
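
A minimal sketch of what such a warning could look like, assuming the integer
branch is split into its own helper. The helper name `generate_int_input` and
the optional `high` parameter are hypothetical and not part of this PR; the
default of 10000 is taken from the diff above.

```python
import warnings
from typing import List, Optional

import numpy as np


def generate_int_input(
    input_shape: List[int],
    input_dtype: str,
    high: Optional[int] = None,  # hypothetical: exclusive upper bound, e.g. the embedding table size
) -> np.ndarray:
    """Generate random non-negative integers in [0, high)."""
    if high is None:
        # Fall back to the bound used in the diff, but tell the caller that
        # the values may be out of range if they are used as lookup indices.
        high = 10000
        warnings.warn(
            f"Generating {input_dtype} input in [0, {high}); if this tensor indexes "
            "an embedding table with fewer rows, buffer accesses may go out of bounds."
        )
    return np.random.randint(low=0, high=high, size=input_shape, dtype=input_dtype)
```

A caller that knows the table size could pass it explicitly, e.g.
`generate_int_input([16], "int32", high=5)`, which also silences the warning.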