cee1 commented on a change in pull request #10650:
URL: https://github.com/apache/tvm/pull/10650#discussion_r829712150
##########
File path: python/tvm/autotvm/task/relay_integration.py
##########
@@ -151,11 +160,66 @@ def extract_from_multiple_program(mods, params, target,
target_host=None, ops=No
# create tasks for target
tasks = []
- for task_name, args in env.get_tasks():
+ for task_name, subgraph_name, args in env.get_tasks():
try:
- tsk = create(task_name, args, target=target)
+ if GLOBAL_SCOPE.tune_subgraph:
+ if subgraph_name is not None:
+ # If tuning subgraph, use subgraph_name as task_name
+ tsk = create(subgraph_name, args, target=target)
+ else:
+ continue
+ else:
+ tsk = create(task_name, args, target=target)
tasks.append(tsk)
except topi.InvalidShapeError:
logger.warning("Invalid shape during AutoTVM task creation")
return tasks
+
+
+@tvm._ffi.register_func("auto_tvm.relay_integration.is_tune_subgraph")
+def is_tune_subgraph():
+ return GLOBAL_SCOPE.tune_subgraph
+
+
+@tvm._ffi.register_func("auto_tvm.relay_integration.register_subgraph_task")
+def register_subgraph_task(subgraph_name, outs, best_impl_name):
+ """In task extracting phase this function registers subgraph as tunable
autotvm
+ task and warp returning outputs with "workload" attached as subgraph's
topi compute.
+ In building phase it attaches "workload" to outputs and returns them to
te_compiler_cache.
+
+ Parameters
+ ----------
+ subgraph_name: str
+ The name of the subgraph
+ outs: list
+ The outputs of the subgraph
+ best_impl_name: str
+ The anchor implementation name of the subgraph
+
+ Returns
+ -------
+ outs: list
+ The subgraph's outputs with workload attached to the last op.
+ """
+ global g_registered_subgraphs_extracted
+ env = TaskExtractEnv.get()
+ tunable_op_list = [item[0] for item in env.get_tasks()]
+ if GLOBAL_SCOPE.tune_subgraph:
+ # use inputs+outs as identifier of subgraph task
+ args = _traverse_to_get_io_tensors(outs)
+ if (
+ env is not None and env.tracing and best_impl_name in
tunable_op_list
+ ): # extract task phase
+ # remove vm compiler prefix
+ subgraph_name = re.sub(r"vm_mod_", "", subgraph_name)
Review comment:
`re.sub(r"vm_mod_", ...` -> `re.sub(r"^vm_mod_", ...`
##########
File path: python/tvm/autotvm/task/task.py
##########
@@ -48,6 +50,105 @@ def _lookup_task(name):
return task
+def _get_compute(name):
+ """get compute by given name.
+
+ Parameters
+ ----------
+ name: name of compute
+ """
+ task = _lookup_task(name)
+ return task.fcompute
+
+
+def format_subgraph_task_name(name, args):
+ if re.match(r"\w+\d+$", name):
+ # remove number of reduplicative subgraph name like '_1'
+ name = re.sub(r"[_]\d+$", "", name)
Review comment:
`re.sub(r"[_]\d+$", ...)` -> `re.sub(r"_\d+$", ...)`
##########
File path: python/tvm/autotvm/task/task.py
##########
@@ -48,6 +50,105 @@ def _lookup_task(name):
return task
+def _get_compute(name):
+ """get compute by given name.
+
+ Parameters
+ ----------
+ name: name of compute
+ """
+ task = _lookup_task(name)
+ return task.fcompute
+
+
+def format_subgraph_task_name(name, args):
+ if re.match(r"\w+\d+$", name):
+ # remove number of reduplicative subgraph name like '_1'
+ name = re.sub(r"[_]\d+$", "", name)
+ str_args = format_args(args).encode("utf-8")
+ hash_key = hashlib.md5(str_args).hexdigest()
+ return name + "_" + hash_key
+
+
+def format_args(args):
+ """format arguments of a topi function to a string
+
+ Parameters
+ ----------
+ args: list of hashable or Tensor
+ """
+
+ def _encode(x):
Review comment:
What about using `serialize_args(...)` instead? It returns a tuple, which
could then be hashed (as part of the subgraph key).
##########
File path: python/tvm/autotvm/task/relay_integration.py
##########
@@ -151,11 +160,66 @@ def extract_from_multiple_program(mods, params, target,
target_host=None, ops=No
# create tasks for target
tasks = []
- for task_name, args in env.get_tasks():
+ for task_name, subgraph_name, args in env.get_tasks():
try:
- tsk = create(task_name, args, target=target)
+ if GLOBAL_SCOPE.tune_subgraph:
+ if subgraph_name is not None:
+ # If tuning subgraph, use subgraph_name as task_name
+ tsk = create(subgraph_name, args, target=target)
+ else:
+ continue
+ else:
+ tsk = create(task_name, args, target=target)
tasks.append(tsk)
except topi.InvalidShapeError:
logger.warning("Invalid shape during AutoTVM task creation")
return tasks
+
+
+@tvm._ffi.register_func("auto_tvm.relay_integration.is_tune_subgraph")
+def is_tune_subgraph():
+ return GLOBAL_SCOPE.tune_subgraph
+
+
+@tvm._ffi.register_func("auto_tvm.relay_integration.register_subgraph_task")
+def register_subgraph_task(subgraph_name, outs, best_impl_name):
+ """In task extracting phase this function registers subgraph as tunable
autotvm
+ task and warp returning outputs with "workload" attached as subgraph's
topi compute.
+ In building phase it attaches "workload" to outputs and returns them to
te_compiler_cache.
+
+ Parameters
+ ----------
+ subgraph_name: str
+ The name of the subgraph
+ outs: list
+ The outputs of the subgraph
+ best_impl_name: str
+ The anchor implementation name of the subgraph
+
+ Returns
+ -------
+ outs: list
+ The subgraph's outputs with workload attached to the last op.
+ """
+ global g_registered_subgraphs_extracted
+ env = TaskExtractEnv.get()
+ tunable_op_list = [item[0] for item in env.get_tasks()]
+ if GLOBAL_SCOPE.tune_subgraph:
+ # use inputs+outs as identifier of subgraph task
+ args = _traverse_to_get_io_tensors(outs)
+ if (
+ env is not None and env.tracing and best_impl_name in
tunable_op_list
+ ): # extract task phase
+ # remove vm compiler prefix
+ subgraph_name = re.sub(r"vm_mod_", "", subgraph_name)
+ subgraph_name = format_subgraph_task_name(subgraph_name, args)
+ if subgraph_name not in g_registered_subgraphs_extracted:
+ g_registered_subgraphs_extracted.append(subgraph_name)
+ return register_topi_subgraph(best_impl_name, args,
subgraph_name, outs)
+ elif best_impl_name in tunable_op_list: # build phase
+ # remove codegen prefix
+ subgraph_name = re.sub(r"tvmgen_default_", "", subgraph_name)
Review comment:
`re.sub(r"tvmgen_default_", ...)` -> `re.sub(r"^tvmgen_default_", ...)`
##########
File path: python/tvm/autotvm/task/task.py
##########
@@ -48,6 +50,105 @@ def _lookup_task(name):
return task
+def _get_compute(name):
+ """get compute by given name.
+
+ Parameters
+ ----------
+ name: name of compute
+ """
+ task = _lookup_task(name)
+ return task.fcompute
+
+
+def format_subgraph_task_name(name, args):
+ if re.match(r"\w+\d+$", name):
+ # remove number of reduplicative subgraph name like '_1'
+ name = re.sub(r"[_]\d+$", "", name)
+ str_args = format_args(args).encode("utf-8")
+ hash_key = hashlib.md5(str_args).hexdigest()
+ return name + "_" + hash_key
+
+
+def format_args(args):
+ """format arguments of a topi function to a string
+
+ Parameters
+ ----------
+ args: list of hashable or Tensor
+ """
+
+ def _encode(x):
+ if isinstance(x, tensor.Tensor):
+ ret = ""
+ for s in get_const_tuple(x.shape):
+ ret = ret + str(s)
+ return ret + x.dtype
+ if isinstance(x, (tuple, list, container.Array)):
+ ret = ""
+ for a in x:
+ ret = ret + _encode(a)
+ return ret
+ if isinstance(x, (str, int, float, expr.Var, expr.Any)):
+ return str(x)
+ if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
+ return str(x.value)
+ if isinstance(x, runtime.container.String):
+ return str(x)
+ if x is None:
+ return None
+ raise RuntimeError(
+ 'Do not support type "%s" in argument. Consider to use'
+ "primitive types or tvm.tir.Var only" % type(x)
+ )
+
+ ret = ""
+ for t in args:
+ ret = ret + _encode(t)
+ return ret
+
+
+def _traverse_to_get_io_tensors(outs):
+ """Traverse from a list of output tensors to get input/output tensors.
+
+ Parameters
+ ----------
+ outs: List[Tensor]
+ The output tensors
+
+ Returns
+ -------
+ io_tensors: List[Tensor]
+ The input and output tensors with static shape
+ """
+ inputs = []
+ visited = set()
+
+ def traverse(t):
+ # We cannot directly add tensors to the set, because the comparison of
+ # two tensors with ndim=0 is ambiguous.
+ assert t.handle is not None
+ if t.handle.value in visited:
+ return
+ if isinstance(t.op, PlaceholderOp):
+ inputs.append(t)
+ elif isinstance(t.op, ComputeOp):
+ for x in t.op.input_tensors:
+ traverse(x)
+ visited.add(t.handle.value)
+
+ for t in outs:
+ traverse(t)
+
+ io_tensors = inputs + list(outs)
+ for t in io_tensors:
+ # Reject the compute if any of its I/O tensors has dynamic shape.
+ if any([not isinstance(v, int) for v in get_const_tuple(t.shape)]):
+ return []
Review comment:
What happens if `[]` is returned here?
Alternatively, consider adding a FIXME annotation.
##########
File path: python/tvm/autotvm/task/task.py
##########
@@ -244,8 +345,12 @@ def __call__(self, *args, **kwargs):
def _default_func(self, *args, **kwargs):
assert callable(self.fcompute) and callable(self.fschedule)
out = self.fcompute(*args, **kwargs)
- arg_bufs = [out] + self._get_inputs(out)
- s = self.fschedule([out])
+ if GLOBAL_SCOPE.current.tune_subgraph:
+ arg_bufs = _traverse_to_get_io_tensors(out)
+ s = self.fschedule(out)
+ else:
+ arg_bufs = [out] + self._get_inputs(out)
Review comment:
Note: the `arg_bufs` returned by line 349 has the "inputs" first,
followed by the "outputs".
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]