tkonolige commented on a change in pull request #8492:
URL: https://github.com/apache/tvm/pull/8492#discussion_r687938717



##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -817,21 +818,66 @@ def prepare_input_map(args):
     return tensor_input_map
 
 
+def prepare_runner_args(inp, build_res):
+    """This function prepares the pre-defined arguments in 
`TASK_INPUT_BUFFER_TABLE` for local/rpc
+    runner in main process
+
+    Parameters
+    ----------
+    inp : MeasureInput
+        Measure input to be measured.
+
+    build_res : BuildResult
+        Build result to be measured.
+
+    Returns
+    -------
+    List[Optional[numpy.ndarray]] :
+        List of arguments for running the program. If the argument does not 
have a pre-defined input
+        buffer, None is added to the list as a placeholder.
+
+    """
+    # pylint: disable=import-outside-toplevel
+    from .search_task import get_task_input_buffer  # lazily import to avoid 
recursive dependency
+
+    task_input_names = inp.task.task_input_names
+    tensor_input_map = prepare_input_map(build_res.args)
+    if not task_input_names:
+        tensor_input_map = {}
+    args = []
+    task_inputs_count = 0
+    for arg in build_res.args:
+        if arg in tensor_input_map:
+            tensor_name = tensor_input_map[arg]
+            if tensor_name in task_input_names:
+                task_input_buffer = 
get_task_input_buffer(inp.task.workload_key, tensor_name)
+                # convert tvm.NDArray to picklable numpy.ndarray
+                args.append(task_input_buffer.numpy())

Review comment:
       Oh, I missed that you tested this with sparse inputs. My bad!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to