This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 70e11d3  [Autotvm] Fix autotvm customized template (#5034)
70e11d3 is described below

commit 70e11d32cfe847a9536f7511b68a19ea23ac8221
Author: Haichen Shen <[email protected]>
AuthorDate: Thu Mar 12 10:01:22 2020 -0700

    [Autotvm] Fix autotvm customized template (#5034)
    
    * init
    
    * fix template
    
    * tweak naming
---
 python/tvm/autotvm/__init__.py                     |   2 +-
 python/tvm/autotvm/graph_tuner/base_graph_tuner.py |   2 +-
 python/tvm/autotvm/task/__init__.py                |   3 +-
 python/tvm/autotvm/task/task.py                    | 113 ++++++++++++++-------
 python/tvm/autotvm/task/topi_integration.py        |   8 +-
 tests/python/integration/test_tuning.py            |   2 +-
 tests/python/unittest/test_autotvm_common.py       |   4 +-
 .../unittest/test_autotvm_dispatch_context.py      |   2 +-
 tutorials/autotvm/tune_conv2d_cuda.py              |   2 +-
 tutorials/autotvm/tune_simple_template.py          |   4 +-
 tutorials/optimize/opt_matmul_auto_tensorcore.py   |   4 +-
 vta/tutorials/autotvm/tune_relay_vta.py            |   2 +-
 12 files changed, 94 insertions(+), 54 deletions(-)

diff --git a/python/tvm/autotvm/__init__.py b/python/tvm/autotvm/__init__.py
index eab4ddf..6b5fafc 100644
--- a/python/tvm/autotvm/__init__.py
+++ b/python/tvm/autotvm/__init__.py
@@ -42,7 +42,7 @@ from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
 from .task import get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, register_customized_task, \
+    register_topi_compute, register_topi_schedule, template, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE
diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
index c6b79fa..bb9c52d 100644
--- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
+++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -42,7 +42,7 @@ def get_infer_layout(task_name):
         return topi.nn.depthwise_conv2d_infer_layout
     raise ValueError("Cannot find infer layout for task %s" % task_name)
 
-@autotvm.register_customized_task("layout_transform")
+@autotvm.template("layout_transform")
 def layout_transform(*args):
     """Autotvm layout transform template."""
     cfg = get_config()
diff --git a/python/tvm/autotvm/task/__init__.py b/python/tvm/autotvm/task/__init__.py
index 29313d4..7e18fca 100644
--- a/python/tvm/autotvm/task/__init__.py
+++ b/python/tvm/autotvm/task/__init__.py
@@ -22,8 +22,7 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
 
-from .task import Task, create, get_config, args_to_workload, \
-    register_customized_task
+from .task import Task, create, get_config, args_to_workload, template
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index c75105b..ddee149 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -186,25 +186,35 @@ class Task(object):
 
 TASK_TABLE = {}
 
-class TopiTemplate(object):
-    """Topi template that holds the topi compute and schedule function"""
+class TaskTemplate(object):
+    """
+    Task template is used to creates a tunable AutoTVM task.
+
+    It can be defined by a pair of compute and schedule function using
+    `_register_task_compute` and `_register_task_schedule`,
+    or by a customized task creation function that is more flexible using
+    `_register_customized_task`.
+
+    Note that when customized func is registered, compute and schedule function
+    will be ignored
+    """
     def __init__(self):
-        self.compute = None
-        self.schedule = None
-        self.customized_func = None
+        self.fcompute = None
+        self.fschedule = None
+        self.fcustomized = None
 
     def __call__(self, *args, **kwargs):
         args = deserialize_args(args)
-        if self.customized_func is None:
+        if self.fcustomized is None:
             return self._default_func(*args, **kwargs)
-        assert callable(self.customized_func)
-        return self.customized_func(*args, **kwargs)
+        assert callable(self.fcustomized)
+        return self.fcustomized(*args, **kwargs)
 
     def _default_func(self, *args, **kwargs):
-        assert callable(self.compute) and callable(self.schedule)
-        out = self.compute(*args, **kwargs)
+        assert callable(self.fcompute) and callable(self.fschedule)
+        out = self.fcompute(*args, **kwargs)
         arg_bufs = [out] + self.get_inputs(out)
-        s = self.schedule([out])
+        s = self.fschedule([out])
         return s, arg_bufs
 
     def get_inputs(self, out):
@@ -218,7 +228,7 @@ class TopiTemplate(object):
                 queue.extend(t.op.input_tensors)
         return inputs
 
-def register_task_compute(name, func=None):
+def _register_task_compute(name, func=None):
     """Register compute function to autotvm task
 
     Parameters
@@ -237,17 +247,17 @@ def register_task_compute(name, func=None):
     """
     def _do_reg(f):
         if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
+            TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
-        if tmpl.compute is not None:
+        if tmpl.fcompute is not None:
             raise ValueError("Compute is already registered in autoTVM task %s" % name)
-        tmpl.compute = f
+        tmpl.fcompute = f
         return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def register_task_schedule(name, func=None):
+def _register_task_schedule(name, func=None):
     """Register schedule function to autotvm task
 
     Parameters
@@ -266,24 +276,19 @@ def register_task_schedule(name, func=None):
     """
     def _do_reg(f):
         if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
+            TASK_TABLE[name] = TaskTemplate()
         tmpl = TASK_TABLE[name]
-        if tmpl.schedule is not None:
+        if tmpl.fschedule is not None:
             raise ValueError("Schedule is already registered in autoTVM task %s" % name)
-        tmpl.schedule = f
+        tmpl.fschedule = f
         return f
     if func:
         return _do_reg(func)
     return _do_reg
 
-def register_customized_task(name, func=None):
+def _register_customized_task(name, func=None):
     """Register a customized function to AutoTVM task.
 
-    In most cases, you can just use register_topi_compute and register_topi_schedule
-    with the same task name to define an AutoTVM task. However, you can also
-    create a customized AutoTVM task that defines a tunable template or performs
-    extra layout transform before invoking compute/schedule function.
-
     Parameters
     ----------
     name: str
@@ -297,6 +302,37 @@ def register_customized_task(name, func=None):
     -------
     decorator: callable
         A decorator
+    """
+    def _do_reg(f):
+        if name not in TASK_TABLE:
+            TASK_TABLE[name] = TaskTemplate()
+        tmpl = TASK_TABLE[name]
+        if tmpl.fcustomized is not None:
+            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
+        tmpl.fcustomized = f
+        return f
+    if func:
+        return _do_reg(func)
+    return _do_reg
+
+
+def template(task_name, func=None):
+    """Decorate a function as a tunable schedule template.
+
+    Parameters
+    ----------
+    task_name: str
+        The task name
+
+    func: None or callable
+        A callable template function.
+        If it is None, return a decorator.
+        If is callable, decorate this function.
+
+    Returns
+    -------
+    func: callable
+        The decorated function
 
     Examples
     --------
@@ -304,7 +340,7 @@ def register_customized_task(name, func=None):
 
     .. code-block:: python
 
-        @autotvm.register_customized_task("matmul")
+        @autotvm.template("matmul")
         def matmul(N, L, M, dtype):
             A = te.placeholder((N, L), name='A', dtype=dtype)
             B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -331,17 +367,22 @@ def register_customized_task(name, func=None):
 
             return s, [A, B, C]
     """
-    def _do_reg(f):
-        if name not in TASK_TABLE:
-            TASK_TABLE[name] = TopiTemplate()
-        tmpl = TASK_TABLE[name]
-        if tmpl.customized_func is not None:
-            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
-        tmpl.customized_func = f
-        return f
+    def _decorate(f):
+        def wrapper(*args, **kwargs):
+            assert not kwargs, "Do not support kwargs in template function call"
+            workload = args_to_workload(args, task_name)
+            tgt = _target.Target.current()
+            cfg = DispatchContext.current.query(tgt, workload)
+            with ApplyConfig(cfg):
+                return f(*args, **kwargs)
+
+        _register_customized_task(task_name, f)
+        return wrapper
+
     if func:
-        return _do_reg(func)
-    return _do_reg
+        return _decorate(func)
+    return _decorate
+
 
 def create(task_name, args, target, target_host=None):
     """Create a tuning task and initialize its search space
diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py
index e1c0913..67f9780 100644
--- a/python/tvm/autotvm/task/topi_integration.py
+++ b/python/tvm/autotvm/task/topi_integration.py
@@ -30,8 +30,8 @@ import tvm.te._ffi_api
 from tvm import target as _target
 from tvm.te import tensor
 
-from .task import args_to_workload, DispatchContext, \
-    register_task_compute, register_task_schedule, serialize_args
+from .task import args_to_workload, serialize_args, DispatchContext, \
+    _register_task_compute, _register_task_schedule
 
 
 # Task extractor for relay program
@@ -142,7 +142,7 @@ def register_topi_compute(task_name, func=None):
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
     def _decorate(topi_compute):
-        @register_task_compute(task_name)
+        @_register_task_compute(task_name)
         def wrapper(*args, **kwargs):
             """wrapper function for topi compute"""
             assert not kwargs, "Do not support kwargs in template function call"
@@ -212,7 +212,7 @@ def register_topi_schedule(task_name, func=None):
     See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
     """
     def _decorate(topi_schedule):
-        @register_task_schedule(task_name)
+        @_register_task_schedule(task_name)
         def wrapper(outs, *args, **kwargs):
             """wrapper function for topi schedule"""
             workload = get_workload(outs)
diff --git a/tests/python/integration/test_tuning.py b/tests/python/integration/test_tuning.py
index 60a372c..95b94f6 100644
--- a/tests/python/integration/test_tuning.py
+++ b/tests/python/integration/test_tuning.py
@@ -26,7 +26,7 @@ from tvm import te
 from tvm import autotvm
 from tvm.autotvm.tuner import RandomTuner
 
-@autotvm.register_customized_task("testing/conv2d_no_batching")
+@autotvm.template("testing/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CI, CO, KH, KW):
     """An example template for testing"""
     assert N == 1, "Only consider batch_size = 1 in this template"
diff --git a/tests/python/unittest/test_autotvm_common.py b/tests/python/unittest/test_autotvm_common.py
index a2f9b1d..909dbbc 100644
--- a/tests/python/unittest/test_autotvm_common.py
+++ b/tests/python/unittest/test_autotvm_common.py
@@ -37,7 +37,7 @@ class DummyRunner(Runner):
     def get_build_kwargs(self):
         return {}
 
-@autotvm.register_customized_task("testing/matmul")
+@autotvm.template("testing/matmul")
 def matmul(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -64,7 +64,7 @@ def matmul(N, L, M, dtype):
 
     return s, [A, B, C]
 
-@autotvm.register_customized_task("testing/bad_matmul")
+@autotvm.template("testing/bad_matmul")
 def bad_matmul(N, L, M, dtype):
     if 'bad_device' in tvm.target.Target.current().keys:
         A = te.placeholder((N, L), name='A', dtype=dtype)
diff --git a/tests/python/unittest/test_autotvm_dispatch_context.py b/tests/python/unittest/test_autotvm_dispatch_context.py
index 5a55c4f..8b073c0 100644
--- a/tests/python/unittest/test_autotvm_dispatch_context.py
+++ b/tests/python/unittest/test_autotvm_dispatch_context.py
@@ -22,7 +22,7 @@ from tvm import autotvm
 
 def test_fallback():
 
-    @autotvm.register_customized_task("testing/dispatch/fallback")
+    @autotvm.template("testing/dispatch_fallback")
     def simple_template(a, b):
         cfg = autotvm.get_config()
         assert cfg.is_fallback
diff --git a/tutorials/autotvm/tune_conv2d_cuda.py b/tutorials/autotvm/tune_conv2d_cuda.py
index 260cf5a..3cdbb84 100644
--- a/tutorials/autotvm/tune_conv2d_cuda.py
+++ b/tutorials/autotvm/tune_conv2d_cuda.py
@@ -79,7 +79,7 @@ from tvm import autotvm
 # can be very large (at the level of 10^9 for some input shapes)
 #
 
-@autotvm.register_customized_task("tutorial/conv2d_no_batching")
+@autotvm.template("tutorial/conv2d_no_batching")
 def conv2d_no_batching(N, H, W, CO, CI, KH, KW, stride, padding):
     assert N == 1, "Only consider batch_size = 1 in this template"
 
diff --git a/tutorials/autotvm/tune_simple_template.py b/tutorials/autotvm/tune_simple_template.py
index dd3b9dc..c5a3843 100644
--- a/tutorials/autotvm/tune_simple_template.py
+++ b/tutorials/autotvm/tune_simple_template.py
@@ -103,7 +103,7 @@ def matmul_v0(N, L, M, dtype):
 # In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.
 
 # Matmul V1: List candidate values
-@autotvm.register_customized_task("tutorial/matmul_v1")  # 1. use a decorator
+@autotvm.template("tutorial/matmul_v1")  # 1. use a decorator
 def matmul_v1(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
@@ -183,7 +183,7 @@ def matmul_v1(N, L, M, dtype):
 # When the high level API cannot meet your requirement, you can always fall
 # back to use low level API.
 
-@autotvm.register_customized_task("tutorial/matmul")
+@autotvm.template("tutorial/matmul")
 def matmul(N, L, M, dtype):
     A = te.placeholder((N, L), name='A', dtype=dtype)
     B = te.placeholder((L, M), name='B', dtype=dtype)
diff --git a/tutorials/optimize/opt_matmul_auto_tensorcore.py b/tutorials/optimize/opt_matmul_auto_tensorcore.py
index 490ccdb..aae1333 100644
--- a/tutorials/optimize/opt_matmul_auto_tensorcore.py
+++ b/tutorials/optimize/opt_matmul_auto_tensorcore.py
@@ -95,7 +95,7 @@ def matmul_nn(A, B, L, dtype='float16', layout='NN'):
 #
 # We use AutoTVM to search for best configurations in this schedule.
 
-@autotvm.register_customized_task("tutorial/test_gemm")
+@autotvm.template("tutorial/auto_tensorcore/test_gemm")
 def test_gemm(N, L, M, dtype, layout):
     if (layout == "NN"):
       shape_a = (N, L)
@@ -265,7 +265,7 @@ elif dtype == 'int4' or dtype == 'int1':
   assert(major == 7 and minor == 5 and layout == 'TN')
 
 def tune_and_evaluate(M, N, L, dtype, layout):
-  task = autotvm.task.create("tutorial/test_gemm", args=(N, L, M, dtype, layout),
+  task = autotvm.task.create("tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout),
                              target='cuda')
   print(task.config_space)
 
diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py
index 1d19e5d..c31d8cc 100644
--- a/vta/tutorials/autotvm/tune_relay_vta.py
+++ b/vta/tutorials/autotvm/tune_relay_vta.py
@@ -310,7 +310,7 @@ def register_vta_tuning_tasks():
     # init autotvm env to register VTA operator
     TaskExtractEnv()
 
-    @autotvm.register_customized_task("conv2d_packed.vta")
+    @autotvm.template("conv2d_packed.vta")
     def _topi_nn_conv2d(*args, **kwargs):
         assert not kwargs, "Do not support kwargs in template function call"
         A, W = args[:2]

Reply via email to