merrymercy commented on a change in pull request #7313: URL: https://github.com/apache/tvm/pull/7313#discussion_r586075755
########## File path: python/tvm/auto_scheduler/search_task.py ########## @@ -157,6 +164,149 @@ def __init__( ) +# The map stores special registered buffer for measurement +# This can be used for sparse workloads when we cannot use random tensors for measurement. +# { +# "workload_key_0": { +# "task_input_0": Tensor(...), +# "task_input_1": Tensor(...) +# }, +# "workload_key_1": { +# "task_input_2": Tensor(...), +# "task_input_3": Tensor(...) +# }, +# ... +# } +TASK_INPUT_BUFFER_TABLE = {} + + +def _save_buffer_to_file(buffer_name, buffer_data): + """Save the current Tensor buffer to a numpy file. + + File name will be: {buffer_name}.{buffer_shape}_{buffer_data_type} + """ + np_data = buffer_data.asnumpy() + + buffer_name += "." + for i in np_data.shape: + buffer_name += "%d_" % (i) + buffer_name += "%s" % (np_data.dtype) + + np_data.tofile(buffer_name, " ") + + +def _try_load_buffer_from_file(buffer_name): + """Try to load buffer from a numpy file, if not found, return None. + + File name has the same format as `_save_buffer_to_file`. + """ + filelist = os.listdir() + + for file in filelist: + if file.startswith(buffer_name) and file.count("."): + meta_info = file.split(".")[-1].split("_") + shape = [int(i) for i in meta_info[:-1]] + dtype = meta_info[-1] + buffer_data = np.fromfile(file, dtype=dtype, sep=" ") + buffer_data = buffer_data.reshape(shape) + return ndarray.array(buffer_data) + + return None + + +def register_task_input_buffer( + workload_key, + input_name, + input_data, + overwrite=False, + save_to_file=False, +): + """Register special buffer for measurement. + + Parameters + ---------- + workload_key : str + The workload key of the SearchTask. + + input_name : str + The name of input buffer. + + input_data : Tensor Review comment: ```suggestion input_data : tvm.nd.NDArray ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org