merrymercy commented on a change in pull request #6269: URL: https://github.com/apache/incubator-tvm/pull/6269#discussion_r472674203
########## File path: python/tvm/auto_scheduler/search_policy.py ########## @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +The search policies for TVM Auto-scheduler. + +This contains the specific strategy on how to generate a schedule automatically. We provide a +default EmptyPolicy which always returns an unchanged initial state, and a practices proven +effective SketchPolicy which is able to deal with various ops/subgraphs on different target devices. + +Reference: +L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor +Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020). +""" + +import random + +import tvm._ffi +from tvm.runtime import Object +from .cost_model import RandomModel +from . import _ffi_api + + +@tvm._ffi.register_object("auto_scheduler.SearchCallback") +class SearchCallback(Object): + """Callback function before or after search process""" + + +@tvm._ffi.register_object("auto_scheduler.PreloadMeasuredStates") +class PreloadMeasuredStates(SearchCallback): + """ A SearchCallback to load measured states from the log file for a search policy. 
+ + This can resume the state of the search policy: + - Making sure an already measured state in former searches will never be measured again. + - The history states can be used to speed up the search process(e.g. SketchPolicy uses + history states as starting point to perform Evolutionary Search). + + Parameters + ---------- + filename : str + The measure record to load measured states from. Review comment: ```suggestion The name of the record file ``` ########## File path: python/tvm/auto_scheduler/search_policy.py ########## @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +The search policies for TVM Auto-scheduler. + +This contains the specific strategy on how to generate a schedule automatically. We provide a +default EmptyPolicy which always returns an unchanged initial state, and a practices proven +effective SketchPolicy which is able to deal with various ops/subgraphs on different target devices. Review comment: ```suggestion This contains the strategies to generate a schedule automatically. We provide an EmptyPolicy which always returns an unchanged initial state, and a more advanced SketchPolicy which can deal with various ops/subgraphs on different target devices. 
``` ########## File path: python/tvm/auto_scheduler/search_policy.py ########## @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +The search policies for TVM Auto-scheduler. + +This contains the specific strategy on how to generate a schedule automatically. We provide a +default EmptyPolicy which always returns an unchanged initial state, and a practices proven +effective SketchPolicy which is able to deal with various ops/subgraphs on different target devices. + +Reference: +L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor +Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020). +""" + +import random + +import tvm._ffi +from tvm.runtime import Object +from .cost_model import RandomModel +from . import _ffi_api + + +@tvm._ffi.register_object("auto_scheduler.SearchCallback") +class SearchCallback(Object): + """Callback function before or after search process""" + + +@tvm._ffi.register_object("auto_scheduler.PreloadMeasuredStates") +class PreloadMeasuredStates(SearchCallback): + """ A SearchCallback to load measured states from the log file for a search policy. 
+ + This can resume the state of the search policy: + - Making sure an already measured state in former searches will never be measured again. + - The history states can be used to speed up the search process(e.g. SketchPolicy uses + history states as starting point to perform Evolutionary Search). + + Parameters + ---------- + filename : str + The measure record to load measured states from. + """ + def __init__(self, filename="auto_scheduler_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename) + + +@tvm._ffi.register_object("auto_scheduler.SearchPolicy") +class SearchPolicy(Object): + """ The base class of search policies. """ + + +@tvm._ffi.register_object("auto_scheduler.EmptyPolicy") +class EmptyPolicy(SearchPolicy): + """ This is an example empty search policy which will always generate + the init state of ComputeDAG. + + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + init_search_callbacks : Optional[List[SearchCallback]] + Callback functions called before the search process. + """ + def __init__(self, task, init_search_callbacks=None): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks) + + +@tvm._ffi.register_object("auto_scheduler.SketchPolicy") +class SketchPolicy(SearchPolicy): + """ The search policy that searches in a hierarchical search space defined by sketches. + The policy randomly samples programs from the space defined by sketches and use evolutionary + search to fine-tune them. + + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + schedule_cost_model : CostModel = RandomModel() + The cost model to estimate the complete schedules. + params : Optional[Dict[str, Any]] + Parameters of the search policy. + See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions. + See `DEFAULT_PARAMS` below to find the default values. 
+ seed : Optional[int] + Random seed. + verbose : int = 1 + Verbosity level. 0 for silent, 1 to output information during schedule search. + init_search_callbacks : Optional[List[SearchCallback]] + Callback functions called before the search process, usually used to do extra + initializations. + Possible callbacks: + - auto_scheduler.PreloadMeasuredStates + - auto_scheduler.PreloadCustomSketchRule + TODO(jcf94): Add these search callback implementations. + """ + + DEFAULT_PARAMS = { + "eps_greedy": 0.05, + "retry_search_one_round_on_empty": 10, + + 'evolutionary_search_population': 2048, + "evolutionary_search_use_measured_ratio": 0.2, + + 'cpu_multi_level_tiling_structure': 'SSRSRS', + 'gpu_multi_level_tiling_structure': 'SSSRRSRS', + # Notice: the default thread bind policy of GPU assumes the tiling structure to have at + # least 3 spatial tiling levels in outermost + + 'max_innermost_split_factor': 16, + 'max_vectorize_size': 16, + + 'disable_change_compute_location': 0, + } + + def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1, + init_search_callbacks=None): + if params is None: + params = SketchPolicy.DEFAULT_PARAMS + else: + for key, value in SketchPolicy.DEFAULT_PARAMS.items(): + if key not in params: + params[key] = value + + self.__init_handle_by_constructor__( + _ffi_api.SketchPolicy, task, schedule_cost_model, params, + seed or random.randint(1, 1 << 30), verbose, init_search_callbacks) + + def generate_sketches(self, print_for_debug=False): + """ Generate the sketches, this is mainly used for debug. + + Parameters + ---------- + print_for_debug : bool = False + Whether print out the sketches for debug. + + Returns + ------- + sketches : List[State] + The generated sketches of this search task. 
+ """ + sketches = _ffi_api.SketchPolicyGenerateSketches(self) + if print_for_debug: + for i, s in enumerate(sketches): + print("=" * 20 + " %d " % i + "=" * 20) + print(s) + return sketches + + def sample_initial_population(self, pop_size): + """Sample initial population. + This python interface is mainly used for debugging and testing. + The actual search is all doen in c++. Review comment: ```suggestion The actual search is all done in c++. ``` ########## File path: python/tvm/auto_scheduler/search_policy.py ########## @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +The search policies for TVM Auto-scheduler. + +This contains the specific strategy on how to generate a schedule automatically. We provide a +default EmptyPolicy which always returns an unchanged initial state, and a practices proven +effective SketchPolicy which is able to deal with various ops/subgraphs on different target devices. + +Reference: +L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating High-Performance Tensor +Programs for Deep Learning." arXiv preprint arXiv:2006.06762 (2020). +""" + +import random + +import tvm._ffi +from tvm.runtime import Object +from .cost_model import RandomModel +from . 
import _ffi_api + + +@tvm._ffi.register_object("auto_scheduler.SearchCallback") +class SearchCallback(Object): + """Callback function before or after search process""" + + +@tvm._ffi.register_object("auto_scheduler.PreloadMeasuredStates") +class PreloadMeasuredStates(SearchCallback): + """ A SearchCallback to load measured states from the log file for a search policy. + + This can resume the state of the search policy: + - Making sure an already measured state in former searches will never be measured again. + - The history states can be used to speed up the search process(e.g. SketchPolicy uses + history states as starting point to perform Evolutionary Search). + + Parameters + ---------- + filename : str + The measure record to load measured states from. + """ + def __init__(self, filename="auto_scheduler_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename) + + +@tvm._ffi.register_object("auto_scheduler.SearchPolicy") +class SearchPolicy(Object): + """ The base class of search policies. """ + + +@tvm._ffi.register_object("auto_scheduler.EmptyPolicy") +class EmptyPolicy(SearchPolicy): + """ This is an example empty search policy which will always generate + the init state of ComputeDAG. + + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + init_search_callbacks : Optional[List[SearchCallback]] + Callback functions called before the search process. + """ + def __init__(self, task, init_search_callbacks=None): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy, task, init_search_callbacks) + + +@tvm._ffi.register_object("auto_scheduler.SketchPolicy") +class SketchPolicy(SearchPolicy): + """ The search policy that searches in a hierarchical search space defined by sketches. + The policy randomly samples programs from the space defined by sketches and use evolutionary + search to fine-tune them. 
+ + Parameters + ---------- + task : SearchTask + The SearchTask for the computation declaration. + schedule_cost_model : CostModel = RandomModel() + The cost model to estimate the complete schedules. + params : Optional[Dict[str, Any]] + Parameters of the search policy. + See `src/auto_scheduler/search_policy/sketch_search_policy.h` for the definitions. + See `DEFAULT_PARAMS` below to find the default values. + seed : Optional[int] + Random seed. + verbose : int = 1 + Verbosity level. 0 for silent, 1 to output information during schedule search. + init_search_callbacks : Optional[List[SearchCallback]] + Callback functions called before the search process, usually used to do extra + initializations. + Possible callbacks: + - auto_scheduler.PreloadMeasuredStates + - auto_scheduler.PreloadCustomSketchRule + TODO(jcf94): Add these search callback implementations. + """ + + DEFAULT_PARAMS = { + "eps_greedy": 0.05, + "retry_search_one_round_on_empty": 10, + + 'evolutionary_search_population': 2048, + "evolutionary_search_use_measured_ratio": 0.2, + + 'cpu_multi_level_tiling_structure': 'SSRSRS', + 'gpu_multi_level_tiling_structure': 'SSSRRSRS', + # Notice: the default thread bind policy of GPU assumes the tiling structure to have at + # least 3 spatial tiling levels in outermost + + 'max_innermost_split_factor': 16, + 'max_vectorize_size': 16, + + 'disable_change_compute_location': 0, + } + + def __init__(self, task, schedule_cost_model=RandomModel(), params=None, seed=None, verbose=1, + init_search_callbacks=None): + if params is None: + params = SketchPolicy.DEFAULT_PARAMS + else: + for key, value in SketchPolicy.DEFAULT_PARAMS.items(): + if key not in params: + params[key] = value + + self.__init_handle_by_constructor__( + _ffi_api.SketchPolicy, task, schedule_cost_model, params, + seed or random.randint(1, 1 << 30), verbose, init_search_callbacks) + + def generate_sketches(self, print_for_debug=False): + """ Generate the sketches, this is mainly used for debug. 
Review comment: ```suggestion """ Generate the sketches. This python interface is mainly used for debugging and testing. The actual search is all done in c++. ``` ########## File path: src/auto_scheduler/search_task.cc ########## @@ -44,8 +45,33 @@ HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_l HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, const Target& target_host) { - if (target->kind->name == "llvm") { + if (target->kind->device_type == kDLCPU) { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64); + } else if (target->kind->device_type == kDLGPU) { + auto hardware_params = HardwareParams(-1, 16, 64); + auto* p_hardware_params = hardware_params.CopyOnWrite(); + + auto ctx = TVMContext{kDLGPU, 0}; + auto func = tvm::runtime::Registry::Get("device_api.gpu"); + CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; + auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*()); + + tvm::runtime::TVMRetValue ret; + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); + p_hardware_params->max_shared_memory_per_block = ret; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret); + p_hardware_params->max_registers_per_block = ret; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); + p_hardware_params->max_threads_per_block = ret; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); + p_hardware_params->warp_size = ret; + + p_hardware_params->max_vthread_extent = 4; Review comment: In our internal repo, we use `p_hardware_params->warp_size / 4`, which is 8. I found that some schedules in tophub have values higher than 4.
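The alternative mentioned in this comment, deriving the limit from the warp size instead of hard-coding 4, can be sketched as follows. This is only an illustration in Python: `default_max_vthread_extent` and its `divisor` parameter are hypothetical names, not part of the PR, and the warp size of 32 is an assumption chosen so that the result matches the value 8 quoted in the comment.

```python
def default_max_vthread_extent(warp_size: int, divisor: int = 4) -> int:
    """Hypothetical helper: derive the vthread extent limit from the warp size.

    Mirrors the `warp_size / 4` rule quoted in the review comment; it is not
    an API of TVM or of this PR.
    """
    if warp_size <= 0:
        raise ValueError("warp_size must be positive")
    # Integer division, clamped so that the limit is never below 1.
    return max(1, warp_size // divisor)


# Assumed warp size of 32, which gives the value 8 mentioned in the comment.
print(default_max_vthread_extent(32))
```

Tying the limit to the queried warp size would let it follow the device instead of staying fixed at 4; whether that is the right default for this PR is left to the discussion.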
########## File path: tests/python/unittest/test_auto_scheduler_sketch_generation.py ########## @@ -93,10 +101,80 @@ def test_cpu_conv2d_winograd_sketch(): ''' assert len(sketches) == 3 + +def test_cuda_matmul_sketch(): + if not tvm.context("cuda", 0).exist: + return + sketches = generate_sketches(matmul_auto_scheduler_test, (512, 512, 512), 'cuda') + ''' 1 multi-level tiling sketch ''' Review comment: Do we need to do more fine-grained checks? Only checking the number is not very robust. ########## File path: src/auto_scheduler/search_policy/utils.h ########## @@ -45,6 +45,27 @@ namespace tvm { namespace auto_scheduler { +/*! \brief Return whether the search task is targeting a CPU. */ +inline bool IsCPUTask(const SearchTask& task) { + return ((task)->target->kind->device_type == kDLCPU); +} + +/*! \brief Return whether the search task is targeting a GPU. */ +inline bool IsGPUTask(const SearchTask& task) { + return ((task)->target->kind->device_type == kDLGPU || + (task)->target->kind->device_type == kDLOpenCL); Review comment: This function seems to cover more cases https://github.com/apache/incubator-tvm/blob/e5b793f39fd5b4f84b0aedf06aa376ebe45cf2bc/src/tir/analysis/verify_memory.cc#L151-L154
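One way to make the GPU check easier to extend, in the spirit of the comment on `IsGPUTask`, is a single membership test over a set of device kinds. The sketch below is in Python for brevity; `GPU_DEVICE_KINDS` and `is_gpu_device` are hypothetical names, and the entries beyond `kDLGPU` and `kDLOpenCL` are assumptions for illustration, not a copy of the linked function.

```python
# Device kinds treated as GPU-like. Only kDLGPU and kDLOpenCL come from the
# diff under review; the remaining entries are illustrative assumptions.
GPU_DEVICE_KINDS = frozenset(
    {"kDLGPU", "kDLOpenCL", "kDLVulkan", "kDLMetal", "kDLROCM"}
)


def is_gpu_device(device_kind: str) -> bool:
    """Hypothetical helper: return True if the device kind is GPU-like."""
    return device_kind in GPU_DEVICE_KINDS


for kind in ("kDLCPU", "kDLGPU", "kDLOpenCL", "kDLVulkan"):
    print(kind, is_gpu_device(kind))
```

Keeping the list in one place means a new backend only needs one added entry, and reusing an existing shared helper would avoid two definitions drifting apart.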

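For the comment on `test_cuda_matmul_sketch`, a more fine-grained check could assert on the structure of each sketch rather than only on the number of sketches. The example below is self-contained and does not use the TVM API: each sketch is modeled as a plain list of step names, and `count_steps`, the step names, and the sample data are all hypothetical.

```python
from collections import Counter


def count_steps(sketch):
    """Hypothetical helper: count how often each step kind occurs in a sketch."""
    return Counter(sketch)


# Stand-in for the result of generate_sketches(); real sketches are State objects.
sketches = [["split", "split", "reorder", "cache_read", "bind"]]

# Checking only the number of sketches, as the review comment describes.
assert len(sketches) == 1

# Finer-grained checks: the expected kinds of steps are present.
steps = count_steps(sketches[0])
assert steps["split"] >= 2
assert steps["cache_read"] == 1
assert "bind" in steps
print(dict(steps))
```

Assertions of this kind would fail if a rule stopped firing even when the sketch count stayed the same, which is the robustness gap the comment points at.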