icemelon9 commented on a change in pull request #4644: Relay op strategy
URL: https://github.com/apache/incubator-tvm/pull/4644#discussion_r379875256
########## File path: python/tvm/autotvm/task/topi_integration.py ########## @@ -76,250 +40,49 @@ class TaskExtractEnv: registered = None def __init__(self, allow_duplicate=False): - # pylint: disable=import-outside-toplevel - import topi - - # topi compute -> autotvm task name - self.topi_to_task = { - topi.nn.conv2d: "topi_nn_conv2d", - topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw", - topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw", - topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw", - topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc", - topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8", - topi.nn.dense: "topi_nn_dense", - topi.nn.batch_matmul: "topi_nn_batch_matmul", - topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw", - topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc", - topi.nn.bitserial_dense: "topi_nn_bitserial_dense", - topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw", - topi.nn.conv1d_transpose_ncw: "topi_nn_conv1d_transpose_ncw", - topi.nn.conv3d: "topi_nn_conv3d", - } - - self.topi_to_schedule = { - topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw, - topi.generic.schedule_conv2d_nhwc], - topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw, - topi.generic.schedule_depthwise_conv2d_nhwc], - topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw], - topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw], - topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc], - topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8], - topi.nn.dense: [topi.generic.schedule_dense], - topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul], - topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw], - topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc], - topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense], - topi.nn.deformable_conv2d_nchw: 
[topi.generic.schedule_deformable_conv2d_nchw], - topi.nn.conv1d_transpose_ncw: [topi.generic.schedule_conv1d_transpose_ncw], - topi.nn.conv3d: [topi.generic.schedule_conv3d_ndhwc], - } - - # function reflection for tracing - self.func_to_reflection = { - topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x), - topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x), - topi.nn.conv2d_NCHWc_int8: lambda x: setattr(topi.nn, 'conv2d_NCHWc_int8', x), - topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x), - topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x), - topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x), - topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x), - topi.nn.batch_matmul: lambda x: setattr(topi.nn, 'batch_matmul', x), - topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x), - topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x), - topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x), - topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x), - topi.nn.conv1d_transpose_ncw: lambda x: setattr(topi.nn, 'conv1d_transpose_ncw', x), - topi.nn.conv3d: lambda x: setattr(topi.nn, 'conv3d', x), - } - self.allow_duplicate = allow_duplicate - self._register_topi_task() self.task_collection = [] - self.wanted_topi_funcs = list(self.topi_to_task.keys()) + self.wanted_relay_ops = None self.modified_funcs = [] + self.tracing = False def __enter__(self): self.task_collection = [] - self.modified_funcs = [] - - for topi_compute in self.wanted_topi_funcs: - def _local_scope(compute_func): - """start a scope to hold the local function in for loop""" - - def _tracing_wrapper(*args, **kwargs): - assert not kwargs, "Do not support extracting tuning tasks when " \ - "kwargs is used in TOPI function call. 
" \ - "Please modify it to use only positional args." - key = (self.topi_to_task[compute_func], serialize_args(args)) - if self.allow_duplicate or key not in self.task_collection: - self.task_collection.append(key) - - return compute_func(*args, **kwargs) - - self.func_to_reflection[compute_func](_tracing_wrapper) - self.modified_funcs.append(compute_func) - - _local_scope(topi_compute) + self.tracing = True return self def __exit__(self, exc_type, exc_val, exc_tb): - # revert modification - for func in self.modified_funcs: - self.func_to_reflection[func](func) - - def _register_topi_task(self): - """register tuning wrapper for topi function""" - # pylint: disable=import-outside-toplevel - import topi - - # Avoid double registration for certain targets - if TaskExtractEnv.registered: - return - TaskExtractEnv.registered = True - - # Tuning wrapper for topi functions - @register("topi_nn_conv2d") - def _topi_nn_conv2d(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - layout = args[-2] - C = topi.nn.conv2d(*args, **kwargs) - if layout == 'NCHW': - s = topi.generic.schedule_conv2d_nchw([C]) - elif layout == 'HWCN': - s = topi.generic.schedule_conv2d_hwcn([C]) - elif layout == 'NHWC': - s = topi.generic.schedule_conv2d_nhwc([C]) - else: - raise ValueError("Unsupported layout {}".format(layout)) - return s, [A, W, C] - - @register("topi_nn_depthwise_conv2d_nchw") - def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs) - s = topi.generic.schedule_depthwise_conv2d_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_group_conv2d_nchw") - def _topi_nn_group_conv2d_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = 
topi.nn.group_conv2d_nchw(*args, **kwargs) - s = topi.generic.schedule_group_conv2d_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_conv2d_transpose_nchw") - def _topi_nn_conv2d_transpose_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.conv2d_transpose_nchw(*args, **kwargs) - s = topi.generic.schedule_conv2d_transpose_nchw([C]) - return s, [A, W, C] - - @register("topi_nn_conv1d_transpose_ncw") - def _topi_nn_conv1d_transpose_ncw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.conv1d_transpose_ncw(*args, **kwargs) - s = topi.generic.schedule_conv1d_transpose_ncw([C]) - return s, [A, W, C] - - @register("topi_nn_conv3d") - def _topi_nn_conv3d(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.conv3d(*args, **kwargs) - s = topi.generic.schedule_conv3d_ndhwc([C]) - return s, [A, W, C] - - @register("topi_nn_dense") - def _topi_nn_dense(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - if len(args) > 2: - data, weight, bias = args[:3] - else: - data, weight = args - bias = None - C = topi.nn.dense(*args, **kwargs) - s = topi.generic.schedule_dense([C]) - if bias is not None: - return s, [data, weight, bias, C] - return s, [data, weight, C] - - @register("topi_nn_batch_matmul") - def _topi_nn_batch_matmul(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, B = args - C = topi.nn.batch_matmul(A, B) - s = topi.generic.schedule_batch_matmul([C]) - return s, [A, B, C] - - @register("topi_nn_bitserial_conv2d_nhwc") - def _topi_bitserial_conv2d_nhwc(*args, **kwargs): - args = deserialize_args(args) - C = 
topi.nn.bitserial_conv2d_nhwc(*args, **kwargs) - s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C]) - A, W = args[:2] - return s, [A, W, C] - - @register("topi_nn_bitserial_conv2d_nchw") - def _topi_bitserial_conv2d_nchw(*args, **kwargs): - args = deserialize_args(args) - C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs) - s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C]) - A, W = args[:2] - return s, [A, W, C] - - @register("topi_nn_bitserial_dense") - def _topi_nn_bitserial_dense(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.bitserial_dense(*args, **kwargs) - s = topi.generic.schedule_bitserial_dense([C]) - return s, [A, W, C] - - @register("topi_nn_deformable_conv2d_nchw") - def _topi_nn_deformable_conv2d_nchw(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, Offset, W = args[:3] - C = topi.nn.deformable_conv2d_nchw(*args, **kwargs) - s = topi.generic.schedule_deformable_conv2d_nchw([C]) - return s, [A, Offset, W, C] - - @register("topi_nn_conv2d_NCHWc") - def _topi_nn_conv2d_NCHWc(*args, **kwargs): - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - A, W = args[:2] - C = topi.nn.conv2d_NCHWc(*args, **kwargs) - s = topi.generic.schedule_conv2d_NCHWc([C]) - return s, [A, W, C] + self.tracing = False - def reset(self, wanted_topi_funcs): + def reset(self, wanted_relay_ops=None): """Reset task collections Parameters ---------- - wanted_topi_funcs: List of function - The topi function to be extracted + wanted_relay_ops: List of relay.op.Op + The relay ops to be extracted """ self.task_collection = [] - self.wanted_topi_funcs = wanted_topi_funcs + self.wanted_relay_ops = wanted_relay_ops + + def add_task(self, task_name, args): + """Add AutoTVM task + + Parameters + ---------- + task_name: str + AutoTVM task name. 
+ + args: tuple + Arguments to the TOPI function. + + cond: SpecializedCondition Review comment: removed ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services