merrymercy commented on a change in pull request #7313:
URL: https://github.com/apache/tvm/pull/7313#discussion_r586075755



##########
File path: python/tvm/auto_scheduler/search_task.py
##########
@@ -157,6 +164,149 @@ def __init__(
         )
 
 
+# The map stores specially registered buffers for measurement.
+# This can be used for sparse workloads when we cannot use random tensors for
+# measurement. It is keyed first by workload key, then by task input name:
+# {
+#     "workload_key_0": {
+#         "task_input_0": Tensor(...),
+#         "task_input_1": Tensor(...)
+#     },
+#     "workload_key_1": {
+#         "task_input_2": Tensor(...),
+#         "task_input_3": Tensor(...)
+#     },
+#     ...
+# }
+TASK_INPUT_BUFFER_TABLE = {}
+
+
+def _save_buffer_to_file(buffer_name, buffer_data):
+    """Save the current Tensor buffer to a numpy file.
+
+    File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type}
+    """
+    np_data = buffer_data.asnumpy()
+
+    buffer_name += "."
+    for i in np_data.shape:
+        buffer_name += "%d_" % (i)
+    buffer_name += "%s" % (np_data.dtype)
+
+    np_data.tofile(buffer_name, " ")
+
+
+def _try_load_buffer_from_file(buffer_name):
+    """Try to load buffer from a numpy file, if not found, return None.
+
+    File name has a same format as `_save_buffer_to_file`.
+    """
+    filelist = os.listdir()
+
+    for file in filelist:
+        if file.startswith(buffer_name) and file.count("."):
+            meta_info = file.split(".")[-1].split("_")
+            shape = [int(i) for i in meta_info[:-1]]
+            dtype = meta_info[-1]
+            buffer_data = np.fromfile(file, dtype=dtype, sep=" ")
+            buffer_data = buffer_data.reshape(shape)
+            return ndarray.array(buffer_data)
+
+    return None
+
+
+def register_task_input_buffer(
+    workload_key,
+    input_name,
+    input_data,
+    overwrite=False,
+    save_to_file=False,
+):
+    """Register special buffer for measurement.
+
+    Parameters
+    ----------
+    workload_key : str
+        The workload key of the SearchTask.
+
+    input_name : str
+        The name of input buffer.
+
+    input_data : Tensor

Review comment:
   ```suggestion
       input_data : tvm.nd.NDArray
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to