jcf94 commented on a change in pull request #7313:
URL: https://github.com/apache/tvm/pull/7313#discussion_r587971202



##########
File path: python/tvm/auto_scheduler/search_task.py
##########
@@ -157,6 +164,149 @@ def __init__(
         )
 
 
+# This map stores specially registered buffers used for measurement.
+# It can be used for sparse workloads, where random tensors cannot be used
+# for measurement. Layout (workload key -> task input name -> Tensor):
+# {
+#     "workload_key_0": {
+#         "task_input_0": Tensor(...),
+#         "task_input_1": Tensor(...)
+#     },
+#     "workload_key_1": {
+#         "task_input_2": Tensor(...),
+#         "task_input_3": Tensor(...)
+#     },
+#     ...
+# }
+TASK_INPUT_BUFFER_TABLE = {}
+
+
+def _save_buffer_to_file(buffer_name, buffer_data):
+    """Save the current Tensor buffer to a numpy file.
+
+    File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type}
+    """
+    np_data = buffer_data.asnumpy()
+
+    buffer_name += "."
+    for i in np_data.shape:
+        buffer_name += "%d_" % (i)
+    buffer_name += "%s" % (np_data.dtype)
+
+    np_data.tofile(buffer_name, " ")
+
+
+def _try_load_buffer_from_file(buffer_name):
+    """Try to load buffer from a numpy file, if not found, return None.
+
+    File name has a same format as `_save_buffer_to_file`.
+    """
+    filelist = os.listdir()
+
+    for file in filelist:
+        if file.startswith(buffer_name) and file.count("."):
+            meta_info = file.split(".")[-1].split("_")
+            shape = [int(i) for i in meta_info[:-1]]
+            dtype = meta_info[-1]
+            buffer_data = np.fromfile(file, dtype=dtype, sep=" ")
+            buffer_data = buffer_data.reshape(shape)
+            return ndarray.array(buffer_data)
+
+    return None
+
+
+def register_task_input_buffer(
+    workload_key,
+    input_name,
+    input_data,
+    overwrite=False,
+    save_to_file=False,
+):
+    """Register special buffer for measurement.
+
+    Parameters
+    ----------
+    workload_key : str
+        The workload key of the SearchTask.
+
+    input_name : str
+        The name of input buffer.
+
+    input_data : tvm.nd.NDArray
+        The input Tensor data.
+
+    overwrite : bool = False
+        Whether overwrite the data if a name has already in the global table.
+
+    save_to_file : bool = False
+        Whether record this buffer to a local file. This can be reused to 
continue the last tuning
+        process.
+    """
+    global TASK_INPUT_BUFFER_TABLE
+
+    if workload_key not in TASK_INPUT_BUFFER_TABLE:
+        TASK_INPUT_BUFFER_TABLE[workload_key] = {}
+    input_table = TASK_INPUT_BUFFER_TABLE[workload_key]
+
+    if not overwrite:
+        if input_name not in input_table.keys():
+            # Try to load buffer data from local file
+            tensor_from_file = _try_load_buffer_from_file(input_name)
+            if tensor_from_file:
+                input_table[input_name] = tensor_from_file
+
+        if input_name in input_table.keys():

Review comment:
       I have actually thought about this, but if we print on every measurement, the output will be a mess.
   Alternatively, we could raise a warning here if the inputs are missing.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to