tqchen commented on a change in pull request #4644: Relay op strategy
URL: https://github.com/apache/incubator-tvm/pull/4644#discussion_r379955826
 
 

 ##########
 File path: python/tvm/autotvm/task/task.py
 ##########
 @@ -116,43 +181,134 @@ def __repr__(self):
             self.name, self.args, self.kwargs, self.workload
         )
 
-TASK_TABLE = {
-}
+TASK_TABLE = {}
+
+class TopiTemplate(object):
+    """Topi template that holds the topi compute and schedule function"""
+    def __init__(self):
+        self.compute = None
+        self.schedule = None
+        self.customized_func = None
+
+    def __call__(self, *args, **kwargs):
+        args = deserialize_args(args)
+        if self.customized_func is None:
+            return self._default_func(*args, **kwargs)
+        assert callable(self.customized_func)
+        return self.customized_func(*args, **kwargs)
+
+    def _default_func(self, *args, **kwargs):
+        assert callable(self.compute) and callable(self.schedule)
+        out = self.compute(*args, **kwargs)
+        arg_bufs = [out] + self.get_inputs(out)
+        s = self.schedule([out])
+        return s, arg_bufs
+
+    def get_inputs(self, out):
+        inputs = []
+        queue = [out]
+        while queue:
+            t = queue.pop(0)
+            if isinstance(t.op, tensor.PlaceholderOp):
+                inputs.append(t)
+            else:
+                queue.extend(t.op.input_tensors)
+        return inputs
 
-def register(name, func=None, override=False):
-    """Register a task function.
+def register_task_compute(name, func=None):
+    """Register compute function to autotvm task
 
     Parameters
     ----------
-    name : str
-        The name to identify the task.
-    func : callable
-        The function to be registered.
-    override : bool
-        Whether override existing registration.
+    name: str
+        The task name
+
+    func: None or callable
+        If it is None, return a decorator.
+        If it is callable, decorate this function.
 
     Returns
     -------
-    func: callable
-        The registered function
+    decorator: callable
+        A decorator
     """
-    def _do_reg(myf):
-        if name in TASK_TABLE and not override:
-            raise ValueError(
-                "Key %s is already registered" % name)
-        TASK_TABLE[name] = myf
-        return myf
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TopiTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.compute is not None:
+            raise ValueError("Compute is already registered in autoTVM task %s" % name)
+        tmpl.compute = f
+        return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def create(func_name, args, target, target_host=None, template_key=None):
+def register_task_schedule(name, func=None):
+    """Register schedule function to autotvm task
+
 
 Review comment:
   add a code example section here

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to