This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6141cac Replace RuntimeError in _lookup_task with deferred error. (#8421)
6141cac is described below
commit 6141cac635fbdaad25b0f8ec3bce130e787922b5
Author: Matt Welsh (OctoML) <[email protected]>
AuthorDate: Fri Jul 9 14:56:34 2021 -0400
Replace RuntimeError in _lookup_task with deferred error. (#8421)
* Replace RuntimeError in _lookup_task with deferred error.
This allows unknown tasks to be created (e.g., when parsing
autotvm log files); an error is raised only if such a task is invoked.
* Format.
* Update python/tvm/autotvm/task/task.py
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Matt Welsh <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
---
python/tvm/autotvm/task/task.py | 29 ++++++++++++++++++++++++-----
1 file changed, 24 insertions(+), 5 deletions(-)
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index 3097c29..ee17508 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -40,11 +40,11 @@ from .space import ConfigSpace
def _lookup_task(name):
    task = TASK_TABLE.get(name)
    if task is None:
-        raise RuntimeError(
-            f"Could not find a registered function for the task {name}. It is "
-            "possible that the function is registered in a python file which was "
-            "not imported in this run."
-        )
+        # Unable to find the given task. This might be because we are
+        # creating a task based on a name that has not been imported.
+        # Rather than raising an exception here, we return a dummy
+        # task that raises RuntimeError only if it is invoked.
+        task = MissingTask(name)
    return task
@@ -264,6 +264,25 @@ class TaskTemplate(object):
return inputs
+class MissingTask(TaskTemplate):
+    """
+    Dummy task template for a task lookup which cannot be resolved.
+    This can occur if the task being requested from _lookup_task()
+    has not been imported in this run. Invoking it raises RuntimeError.
+    """
+
+    def __init__(self, taskname: str):
+        super().__init__()
+        self._taskname = taskname
+
+    def __call__(self, *args, **kwargs):
+        raise RuntimeError(
+            f"Attempting to invoke a missing task {self._taskname}. "
+            "It is possible that the function is registered in a "
+            "Python module that is not imported in this run, or the log is out-of-date."
+        )
+
+
def _register_task_compute(name, func=None):
"""Register compute function to autotvm task