icemelon9 commented on a change in pull request #4644: Relay op strategy
URL: https://github.com/apache/incubator-tvm/pull/4644#discussion_r379875256
 
 

 ##########
 File path: python/tvm/autotvm/task/topi_integration.py
 ##########
 @@ -76,250 +40,49 @@ class TaskExtractEnv:
     registered = None
 
     def __init__(self, allow_duplicate=False):
-        # pylint: disable=import-outside-toplevel
-        import topi
-
-        # topi compute -> autotvm task name
-        self.topi_to_task = {
-            topi.nn.conv2d: "topi_nn_conv2d",
-            topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
-            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
-            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
-            topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
-            topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
-            topi.nn.dense: "topi_nn_dense",
-            topi.nn.batch_matmul: "topi_nn_batch_matmul",
-            topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
-            topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
-            topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
-            topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
-            topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw",
-            topi.nn.conv3d: "topi_nn_conv3d",
-        }
-
-        self.topi_to_schedule = {
-            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
-                             topi.generic.schedule_conv2d_nhwc],
-            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
-                                            topi.generic.schedule_depthwise_conv2d_nhwc],
-            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
-            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
-            topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
-            topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
-            topi.nn.dense: [topi.generic.schedule_dense],
-            topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
-            topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
-            topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
-            topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
-            topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
-            topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw],
-            topi.nn.conv3d: [topi.generic.schedule_conv3d_ndhwc],
-        }
-
-        # function reflection for tracing
-        self.func_to_reflection = {
-            topi.nn.conv2d:                 lambda x: setattr(topi.nn, 'conv2d', x),
-            topi.nn.conv2d_NCHWc:           lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
-            topi.nn.conv2d_NCHWc_int8:      lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x),
-            topi.nn.depthwise_conv2d_nchw:  lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
-            topi.nn.group_conv2d_nchw:      lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
-            topi.nn.conv2d_transpose_nchw:  lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
-            topi.nn.dense:                  lambda x: setattr(topi.nn, 'dense', x),
-            topi.nn.batch_matmul:           lambda x: setattr(topi.nn, 'batch_matmul', x),
-            topi.nn.bitserial_conv2d_nchw:  lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
-            topi.nn.bitserial_conv2d_nhwc:  lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
-            topi.nn.bitserial_dense:        lambda x: setattr(topi.nn, 'bitserial_dense', x),
-            topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
-            topi.nn.conv1d_transpose_ncw:   lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x),
-            topi.nn.conv3d:                 lambda x: setattr(topi.nn, 'conv3d', x),
-        }
-
         self.allow_duplicate = allow_duplicate
-        self._register_topi_task()
         self.task_collection = []
-        self.wanted_topi_funcs = list(self.topi_to_task.keys())
+        self.wanted_relay_ops = None
         self.modified_funcs = []
+        self.tracing = False
 
     def __enter__(self):
         self.task_collection = []
-        self.modified_funcs = []
-
-        for topi_compute in self.wanted_topi_funcs:
-            def _local_scope(compute_func):
-                """start a scope to hold the local function in for loop"""
-
-                def _tracing_wrapper(*args, **kwargs):
-                    assert not kwargs, "Do not support extracting tuning tasks when " \
-                                       "kwargs is used in TOPI function call. " \
-                                       "Please modify it to use only positional args."
-                    key = (self.topi_to_task[compute_func], serialize_args(args))
-                    if self.allow_duplicate or key not in self.task_collection:
-                        self.task_collection.append(key)
-
-                    return compute_func(*args, **kwargs)
-
-                self.func_to_reflection[compute_func](_tracing_wrapper)
-                self.modified_funcs.append(compute_func)
-
-            _local_scope(topi_compute)
+        self.tracing = True
 
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # revert modification
-        for func in self.modified_funcs:
-            self.func_to_reflection[func](func)
-
-    def _register_topi_task(self):
-        """register tuning wrapper for topi function"""
-        # pylint: disable=import-outside-toplevel
-        import topi
-
-        # Avoid double registration for certain targets
-        if TaskExtractEnv.registered:
-            return
-        TaskExtractEnv.registered = True
-
-        # Tuning wrapper for topi functions
-        @register("topi_nn_conv2d")
-        def _topi_nn_conv2d(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            layout = args[-2]
-            C = topi.nn.conv2d(*args, **kwargs)
-            if layout == 'NCHW':
-                s = topi.generic.schedule_conv2d_nchw([C])
-            elif layout == 'HWCN':
-                s = topi.generic.schedule_conv2d_hwcn([C])
-            elif layout == 'NHWC':
-                s = topi.generic.schedule_conv2d_nhwc([C])
-            else:
-                raise ValueError("Unsupported layout {}".format(layout))
-            return s, [A, W, C]
-
-        @register("topi_nn_depthwise_conv2d_nchw")
-        def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_depthwise_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_group_conv2d_nchw")
-        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_group_conv2d_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv2d_transpose_nchw")
-        def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_transpose_nchw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv1d_transpose_ncw")
-        def _topi_nn_conv1d_transpose_ncw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv1d_transpose_ncw(*args, **kwargs)
-            s = topi.generic.schedule_conv1d_transpose_ncw([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_conv3d")
-        def _topi_nn_conv3d(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv3d(*args, **kwargs)
-            s = topi.generic.schedule_conv3d_ndhwc([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_dense")
-        def _topi_nn_dense(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            if len(args) > 2:
-                data, weight, bias = args[:3]
-            else:
-                data, weight = args
-                bias = None
-            C = topi.nn.dense(*args, **kwargs)
-            s = topi.generic.schedule_dense([C])
-            if bias is not None:
-                return s, [data, weight, bias, C]
-            return s, [data, weight, C]
-
-        @register("topi_nn_batch_matmul")
-        def _topi_nn_batch_matmul(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, B = args
-            C = topi.nn.batch_matmul(A, B)
-            s = topi.generic.schedule_batch_matmul([C])
-            return s, [A, B, C]
-
-        @register("topi_nn_bitserial_conv2d_nhwc")
-        def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
-            args = deserialize_args(args)
-            C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
-            s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
-            A, W = args[:2]
-            return s, [A, W, C]
-
-        @register("topi_nn_bitserial_conv2d_nchw")
-        def _topi_bitserial_conv2d_nchw(*args, **kwargs):
-            args = deserialize_args(args)
-            C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
-            A, W = args[:2]
-            return s, [A, W, C]
-
-        @register("topi_nn_bitserial_dense")
-        def _topi_nn_bitserial_dense(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.bitserial_dense(*args, **kwargs)
-            s = topi.generic.schedule_bitserial_dense([C])
-            return s, [A, W, C]
-
-        @register("topi_nn_deformable_conv2d_nchw")
-        def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, Offset, W = args[:3]
-            C = topi.nn.deformable_conv2d_nchw(*args, **kwargs)
-            s = topi.generic.schedule_deformable_conv2d_nchw([C])
-            return s, [A, Offset, W, C]
-
-        @register("topi_nn_conv2d_NCHWc")
-        def _topi_nn_conv2d_NCHWc(*args, **kwargs):
-            assert not kwargs, "Do not support kwargs in template function call"
-            args = deserialize_args(args)
-            A, W = args[:2]
-            C = topi.nn.conv2d_NCHWc(*args, **kwargs)
-            s = topi.generic.schedule_conv2d_NCHWc([C])
-            return s, [A, W, C]
+        self.tracing = False
 
-    def reset(self, wanted_topi_funcs):
+    def reset(self, wanted_relay_ops=None):
         """Reset task collections
 
         Parameters
         ----------
-        wanted_topi_funcs: List of function
-            The topi function to be extracted
+        wanted_relay_ops: List of relay.op.Op
+            The relay ops to be extracted
         """
         self.task_collection = []
-        self.wanted_topi_funcs = wanted_topi_funcs
+        self.wanted_relay_ops = wanted_relay_ops
+
+    def add_task(self, task_name, args):
+        """Add AutoTVM task
+
+        Parameters
+        ----------
+        task_name: str
+            AutoTVM task name.
+
+        args: tuple
+            Arguments to the TOPI function.
+
+        cond: SpecializedCondition
 
 Review comment:
   removed

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to