jcf94 commented on a change in pull request #7313:
URL: https://github.com/apache/tvm/pull/7313#discussion_r587996929
##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -719,6 +720,45 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
return results
+def _prepare_input_map(args):
+ """This function deals with special task inputs.
+
+ Parameters
+ ----------
+ args : List[Tensor]
+ Input/output Tensor of a TVM subgraph.
+
+ Returns
+ -------
+ A Dict[Tensor, str] that maps each input Tensor to its buffer name.
+
+ Note
+ ----
+ The buffer names are specially designed, and these buffers should be
+ provided in `SearchTask(..., task_inputs={...})`.
+ """
+ # pylint: disable=import-outside-toplevel
+ from tvm import topi # lazily import to avoid recursive dependency
+
+ # A dict that maps the input tensor arg to a buffer name
+ tensor_input_map = {}
+
+ # Case 0: Check placeholder name
+ for arg in args:
+ if isinstance(arg.op, tvm.te.PlaceholderOp):
+ if arg.op.name != "placeholder":
+ tensor_input_map[arg] = arg.op.name
+
+ # Case 1: Check sparse op
+ sparse_input_map = topi.nn.sparse.try_get_sparse_input(args)
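The "Case 0" lookup above can be illustrated in isolation: any tensor whose placeholder op carries a non-default name is mapped to that name as its buffer key. Below is a minimal sketch with stand-in `PlaceholderOp`/`Tensor` classes (not the real TVM classes) so the logic can run without TVM:

```python
# Stand-in classes (NOT the real tvm.te types) used only to
# demonstrate the placeholder-name lookup from the patch.
class PlaceholderOp:
    def __init__(self, name):
        self.name = name

class Tensor:
    def __init__(self, op):
        self.op = op

def prepare_input_map(args):
    """Map each specially named placeholder tensor to its buffer name.

    Tensors with the default op name "placeholder" are skipped, mirroring
    the check in the patched _prepare_input_map.
    """
    tensor_input_map = {}
    for arg in args:
        if isinstance(arg.op, PlaceholderOp) and arg.op.name != "placeholder":
            tensor_input_map[arg] = arg.op.name
    return tensor_input_map

# A default-named tensor is ignored; a custom-named one is recorded.
default_t = Tensor(PlaceholderOp("placeholder"))
named_t = Tensor(PlaceholderOp("W_data"))
result = prepare_input_map([default_t, named_t])
```

Here `result` contains only `named_t`, mapped to the string `"W_data"`; the real function additionally handles sparse inputs via `topi.nn.sparse.try_get_sparse_input`.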
Review comment:
> Could we associate the lookup mechanism with `@register_workload`? It would at least be extensible then.

Thanks! This is a pretty good idea, I'll have a try.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]