merrymercy commented on a change in pull request #5962: URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449744269
########## File path: src/ansor/search_policy/search_policy.h ########## @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/search_policy.h + * \brief The base class of search policies, including the abstract definition of search policy and + * other supporting data structures. + * + * The basic schedule search process for Ansor is design to be: + * `Program sampling` -> `Performance Tuning`. + * + * In `Program sampling`, we use some predefined precise or heuristic rules to generate several + * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which + * uses cost model based evolutionary search to select schedules with the best performance. + * + * Candidate schedules are measured against the specific hardware target. + * + * \note Adding a new search policy. + * In design, there's no need for users to implement their own search policy, our formal search + * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule + * mechanism will be provided to enable user-defined template search to serve the same functionality + * as the current AutoTVM template. 
+ * + * This guide is to help understand it better and incase some advanced users have special + * requirements. + * 1. The only funcion that must be implemented is Search(), the design principe for it is to be + * the entry of starting a schedule search process and returns the best schedule get. + * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask. + * This structure also contains some information about the target device. (e.g. knowing the weight + * of the device vector unit, we can limit the max vectorize size during schedule generating) + * 3. SearchCallback provides more flexibility to do extra affairs during the search process. + * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states get + * during the search process. + */ Review comment: ```suggestion This guide is for advanced users who have special requirements. * 1. The only function that must be implemented is Search(), which takes a task as input and returns the best state found. * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask. * This structure also contains some information about the target device. (e.g. knowing the width * of the device vector unit, we can limit the max vectorize size during schedule search) * 3. SearchCallback provides more flexibility to do extra work before/after the search process. * 4. ProgramMeasurer provides a simple but useful API to help check the performance of states obtained * during the search process. */ ``` ########## File path: src/ansor/search_task.h ########## @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_task.h + * \brief Meta information and hardware parameters for a search task. + */ + +#ifndef TVM_ANSOR_SEARCH_TASK_H_ +#define TVM_ANSOR_SEARCH_TASK_H_ + +#include <tvm/target/target.h> + +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +class HardwareParams; + +/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */ +class HardwareParamsNode : public Object { + public: + /*! \brief The number of cores. */ + int num_cores; + /*! \brief The width of vector units in bytes. */ + int vector_unit_bytes; + /*! \brief The size of cache line in bytes. */ + int cache_line_bytes; + + // Some GPU related limitations + // Get from TVM device api + + /*! \brief The max shared memory per block. */ + int max_shared_memory_per_block{INT32_MAX}; + /*! \brief The max register memory per block. */ + int max_registers_per_block{INT32_MAX}; + /*! \brief The max threads per block. */ + int max_threads_per_block{INT32_MAX}; + /*! \brief The max vthread extent. */ + int max_vthread_extent{INT32_MAX}; + /*! \brief The thread numbers of a warp. 
*/ + int warp_size{INT32_MAX}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_cores", &num_cores); + v->Visit("vector_unit_bytes", &vector_unit_bytes); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); + v->Visit("max_registers_per_block", &max_registers_per_block); + v->Visit("max_threads_per_block", &max_threads_per_block); + v->Visit("max_vthread_extent", &max_vthread_extent); + v->Visit("warp_size", &warp_size); + } + + /*! + * \brief Get the default hardware params. + * \param target A `tvm.target`. + * \param target_host A `tvm.target` for host device. + * \return A HardwareParams object. + */ + static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); + + static constexpr const char* _type_key = "ansor.HardwareParams"; + TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); +}; + +/*! + * \brief Managed reference to HardwareParamsNode. + * \sa HardwareParamsNode + */ +class HardwareParams : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param num_cores The number of cores. + * \param vector_unit_bytes The width of vector units in bytes. + * \param cache_line_bytes The size of cache line in bytes. + */ + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes); + + TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); +}; + +/*! + * \brief The computation information and hardware parameters for a specific schedule search task. + */ +class SearchTaskNode : public Object { + public: + /*! \brief The ComputeDAG for target compute declaration. 
*/ Review comment: replace all "target compute declaration" with "input compute declaration" in this file ########## File path: src/ansor/search_policy/search_policy.h ########## @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/search_policy.h + * \brief The base class of search policies, including the abstract definition of search policy and + * other supporting data structures. + * + * The basic schedule search process for Ansor is design to be: + * `Program sampling` -> `Performance Tuning`. + * + * In `Program sampling`, we use some predefined precise or heuristic rules to generate several + * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which + * uses cost model based evolutionary search to select schedules with the best performance. + * + * Candidate schedules are measured against the specific hardware target. + * + * \note Adding a new search policy. + * In design, there's no need for users to implement their own search policy, our formal search + * policy(will be brought later) should be enough to cover most use cases. 
Meanwhile, a custom rule + * mechanism will be provided to enable user-defined template search to serve the same functionality + * as the current AutoTVM template. + * + * This guide is to help understand it better and incase some advanced users have special + * requirements. + * 1. The only funcion that must be implemented is Search(), the design principe for it is to be + * the entry of starting a schedule search process and returns the best schedule get. + * 2. Information about the compute declaration of ops/subgraphs can be acquired from SearchTask. + * This structure also contains some information about the target device. (e.g. knowing the weight + * of the device vector unit, we can limit the max vectorize size during schedule generating) + * 3. SearchCallback provides more flexibility to do extra affairs during the search process. + * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states get + * during the search process. + */ + +#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ + +#include <tvm/node/node.h> + +#include <unordered_set> +#include <vector> + +#include "../search_task.h" + +namespace tvm { +namespace ansor { + +class ProgramMeasurer; +class SearchPolicyNode; + +/*! + * \brief Callback function to be called by the search process. + * This interface allows to do extra initializations before schedule search or extra + * check during/after the schedule search. + */ +class SearchCallbackNode : public Object { + public: + /*! + * \brief Run the registered callback function. + * \param policy A pointer to a SearchPolicyNode. + */ + virtual void Callback(SearchPolicyNode* policy) = 0; + + static constexpr const char* _type_key = "ansor.SearchCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); +}; + +/*! + * \brief Managed reference to SearchCallbackNode. 
+ * \sa SearchCallbackNode + */ +class SearchCallback : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode); +}; + +/*! + * \brief The base class of search policies. + */ +class SearchPolicyNode : public Object { + public: + /*! \brief The current search task. */ + SearchTask cur_task; + /*! + * \brief Verbose level to control the screen output during schedule search. + * 0 for silent, 1 to output information. + */ + int verbose; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("cur_task", &cur_task); + v->Visit("verbose", &verbose); + } + + /*! + * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state + * get during the search process. + * \param task The target search task. Review comment: replace `target search task` with `search task` in all files ########## File path: src/ansor/search_policy/empty_policy.h ########## @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/empty_policy.h + * \brief This is an brief example of search policy. 
Review comment: ```suggestion * \brief A brief example of the search policy which always returns the initial naive schedule (state) ``` ########## File path: python/tvm/ansor/workload_registry.py ########## @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Workload registration and serialization. + +We use a json string to represent a workload (a compute dag). +The format of the string is `[func_name, [args...]]`. +The dag should be the return value of this `func_name(*args)`. + +Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. +These strings are efficient for serialization/matching and wont' be too long. +When we need the dag, we decode the string and call the function, which will return the dag. +""" + +import pickle +import json + +import tvm._ffi +from .utils import serialize_args, deserialize_args + +WORKLOAD_FUNC_REGISTRY = {} + + +def register_workload(func): + """ Register a workload by generation function. + + The input function should take hashable and jsonable arguments + (int, float, tuple of int, tvm.tensor.Tensor, ...) 
and return a list of tvm.tensor.Tensor. + + Parameters + ---------- + func : Function + The target function that returns the compute declaration Tensors. Review comment: Remove all "target" in this file. Replace it with "task" if you really want something before the noun. ########## File path: python/tvm/ansor/measure.py ########## @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Distributed measurement infrastructure to measure the runtime costs of tensor programs + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. + +We implement these in python to utilize python's multiprocessing and error handling +""" + +import os +import time +import shutil +import traceback +import tempfile +import multiprocessing + +import tvm._ffi +from tvm.runtime import Object, module, ndarray +from tvm.driver import build_module +from tvm.ir import transform +from tvm.contrib import tar, ndk + +from . 
import _ffi_api +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout + +# The maximum length of error message +MAX_ERROR_MSG_LEN = 512 + +# Global variables used in build function +GLOBAL_BUILD_ARGUMENTS = None + +@tvm._ffi.register_object("ansor.MeasureCallback") +class MeasureCallback(Object): + """ The base class of measurement callback functions. """ + + +@tvm._ffi.register_object("ansor.MeasureInput") +class MeasureInput(Object): + """ Store the input of a measurement. + + Parameters + ---------- + task : SearchTask + The target SearchTask. Review comment: Remove all "target" before "SearchTask", "search task", "compute declaration". Replace it with "input" if you really want something before the noun. ########## File path: src/ansor/search_task.h ########## @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_task.h + * \brief Meta information and hardware parameters for a search task. + */ + +#ifndef TVM_ANSOR_SEARCH_TASK_H_ +#define TVM_ANSOR_SEARCH_TASK_H_ + +#include <tvm/target/target.h> + +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +class HardwareParams; + +/*! 
\brief The parameters of target hardware used to guide the search process of SearchPolicy. */ +class HardwareParamsNode : public Object { + public: + /*! \brief The number of cores. */ + int num_cores; + /*! \brief The width of vector units in bytes. */ + int vector_unit_bytes; + /*! \brief The size of cache line in bytes. */ + int cache_line_bytes; + + // Some GPU related limitations + // Get from TVM device api Review comment: ```suggestion // GPU-related parameters obtained from the device query API ``` ########## File path: src/ansor/measure.h ########## @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/measure.h + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. + * MeasureInput -> BuildeResult -> MeasureResult Review comment: ```suggestion * The flow of data structures is MeasureInput -> BuildResult -> MeasureResult. ``` ########## File path: python/tvm/ansor/serialization.py ########## @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Serialization and other I/O support for tuning logs (measurement records)""" + +import numpy as np + +import tvm._ffi +from tvm.runtime import Object +from .measure import MeasureCallback, MeasureErrorNo +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.LogToFile") +class LogToFile(MeasureCallback): + """ + A measurement callback that writes measurement records into a file. + + Parameters + ---------- + filename : str + File name for this callback to write log to. + """ + def __init__(self, filename="ansor_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename) + + +@tvm._ffi.register_object("ansor.LogReader") +class LogReader(Object): + """ + Reader of the json log file. + + Parameters + ---------- + filename : str = "ansor_tuning.json" + File name for this reader to load log from. + """ + def __init__(self, filename="ansor_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.LogReader, filename) + + def read_lines(self, max_lines=-1, skip_lines=0): + """ Read multiple lines from the log file. + + Parameters + ---------- + max_lines : int = -1 + The maximum number of lines. -1 means to read all lines. + skip_lines : int = 0 + Skip the first n lines. + + Returns + ------- + inputs : List[MeasureInput] + The MeasureInputs loaded from the log file. + results : List[MeasureResult] + The MeasureResults loaded from the log file. 
+ """ + inputs, results = _ffi_api.LogReaderReadLines(self, max_lines, skip_lines) + return inputs, results + + def __iter__(self): + while True: + ret = _ffi_api.LogReaderReadNext(self) + if not ret: + break + yield ret[0], ret[1] # (input, result) + + +def load_from_file(filename): + """ + Load measurement records from a file. + + Parameters + ---------- + filename : str + File name to load log from. + + Returns + ------- + logs : List[MeasureInput, MeasureResult] + """ + return zip(*LogReader(filename).read_lines()) + + +def append_measure_records_to_file(filename, inputs, results): + """ + Aappend measure records to file. + + Parameters + ---------- + filename : str + File name to write log to. + inputs: List[MeasureInputs] + The target MeasureInputs to be written. Review comment: Remove all "target" before "MeasureInputs", "MeasureResults" or "compute declaration" in this file ########## File path: src/ansor/loop_state.h ########## @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.h + * \brief The definition of the "state" in search. + * + * Each LoopState corresponds to a specific schedule for its target ComputeDAG. + * A LoopState consists of: 1. 
a current loop structure; 2. a history of transformations used to + * construct it. + * The loop structure keeps a preview of how the schedule will finally look like after lowering the + * current state (e.g. number of iterators, the extent of each iterator, the compute_at locations + * ...). During the schedule search process, the loop structure can provide search policy with + * necessary information on how to perform further operations with the current state. + * The transform history is a sequence of TransformStep which will finally be mapped to schedule + * primitives. The steps can also be used for serialization of a state. + * + * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. + * We don't use the existing TVM IR but to extend a new structure on it is because: + * 1. We want fast incremental change to the loop structures, search policy needs to get the + * immediate loop structures update rather than after TVM lowering; + * 2. We want serializable transform history for replay, backtracking, and mutation; + * 3. We may create some macro schedule primitives that represent the combination of several TVM + * schedule primitives. + * + * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. + * Since we share a lot of common objects during search, the transformation is implemented in copy + * on write style. All objects are immutable, which is similar to TVM IR. + */ + +#ifndef TVM_ANSOR_LOOP_STATE_H_ +#define TVM_ANSOR_LOOP_STATE_H_ + +#include <tvm/runtime/container.h> + +#include <functional> + +#include "transform_step.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +class ComputeDAG; + +/*! \brief The type of a stage. */ +enum StageType { + /*! \brief A placeholder stage. */ + kPlaceholder = 0, + /*! \brief A compute stage. */ + kCompute = 1 +}; + +/*! \brief The type of compute location. */ +enum ComputeAtType { + /*! \brief Compute at root. 
*/ + kRoot = 0, + /*! \brief Compute inlined. */ + kInlined = 1, + /*! \brief Compute at some iterator. */ + kIter = 2, +}; + +/*! \brief The type of an iterator. */ +enum IteratorType { + /*! \brief Spatial iterator. */ + kSpace = 0, + /*! \brief Reduction iterator. */ + kReduce = 1, + /*! \brief Fused spatial and reduction iterator. */ + kMixed = 2, + /*! \brief Special iterator. (e.g. virtual root iterator) */ + kSpecial = 3 +}; + +/*! \brief The type of an iterator's annotation. */ +enum IteratorAnnotation { + /*! \brief This iterator has no annotation. */ + kNone = 0, + /*! \brief This iterator has been unrolled. */ + kUnroll = 1, + /*! \brief This iterator has been vectorized. */ + kVectorize = 2, + /*! \brief This iterator has been paralleld. */ + kParallel = 3, + /*! \brief This iterator has been bind to vthread. */ + kVThread = 4, + /*! \brief This iterator has been bind to blockIdx.x. */ + kBlockX = 5, + /*! \brief This iterator has been bind to threadIdx.x. */ + kThreadX = 6, + /*! \brief This iterator has been bind to blockIdx.y. */ + kBlockY = 7, + /*! \brief This iterator has been bind to threadIdx.y. */ + kThreadY = 8, + /*! \brief This iterator has been mapped with a tensorize intrinsic. */ + kTensorized = 9 +}; + +/*! + * \brief A for loop iterator + * Similar to tvm::IterVar in `include/tvm/tir/expr.h` + */ +class IteratorNode : public Object { + public: + /*! \brief The name of this iterator. */ + String name; + /*! \brief The target range of this iterator. */ Review comment: What's the meaning of "target range"? Remove all "target" in this file. ########## File path: src/ansor/search_policy/search_policy.h ########## @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/search_policy.h + * \brief The base class of search policies, including the abstract definition of search policy and + * other supporting data structures. + * + * The basic schedule search process for Ansor is design to be: + * `Program sampling` -> `Performance Tuning`. + * + * In `Program sampling`, we use some predefined precise or heuristic rules to generate several + * initial schedules. Based on these initial starting points, we perform `Performance Tuning` which + * uses cost model based evolutionary search to select schedules with the best performance. + * + * Candidate schedules are measured against the specific hardware target. + * + * \note Adding a new search policy. + * In design, there's no need for users to implement their own search policy, our formal search + * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule + * mechanism will be provided to enable user-defined template search to serve the same functionality + * as the current AutoTVM template. + * + * This guide is to help understand it better and incase some advanced users have special + * requirements. + * 1. The only funcion that must be implemented is Search(), the design principe for it is to be + * the entry of starting a schedule search process and returns the best schedule get. + * 2. 
Information about the compute declaration of ops/subgraphs can be acquired from SearchTask. + * This structure also contains some information about the target device. (e.g. knowing the weight + * of the device vector unit, we can limit the max vectorize size during schedule generating) + * 3. SearchCallback provides more flexibility to do extra affairs during the search process. + * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states get + * during the search process. + */ + +#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ + +#include <tvm/node/node.h> + +#include <unordered_set> +#include <vector> + +#include "../search_task.h" + +namespace tvm { +namespace ansor { + +class ProgramMeasurer; +class SearchPolicyNode; + +/*! + * \brief Callback function to be called by the search process. + * This interface allows to do extra initializations before schedule search or extra + * check during/after the schedule search. + */ +class SearchCallbackNode : public Object { + public: + /*! + * \brief Run the registered callback function. + * \param policy A pointer to a SearchPolicyNode. + */ + virtual void Callback(SearchPolicyNode* policy) = 0; + + static constexpr const char* _type_key = "ansor.SearchCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); +}; + +/*! + * \brief Managed reference to SearchCallbackNode. + * \sa SearchCallbackNode + */ +class SearchCallback : public ObjectRef { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode); +}; + +/*! + * \brief The base class of search policies. + */ +class SearchPolicyNode : public Object { + public: + /*! \brief The current search task. */ + SearchTask cur_task; + /*! + * \brief Verbose level to control the screen output during schedule search. + * 0 for silent, 1 to output information. 
+ */ + int verbose; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("cur_task", &cur_task); + v->Visit("verbose", &verbose); + } + + /*! + * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state + * get during the search process. + * \param task The target search task. + * \param num_measure_trials Total schedules to be tried during this search. + * \param early_stopping Early stop if no better schedule is found. + * \param num_measures_per_round Max measure batch in one search round. + * \param verbose Verbose level. 0 for silent, 1 to output information during schedule search. + * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. + * \param pre_search_callbacks SearchCallback to be called before schedule search. + * \return The best state get. + */ + virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, + int num_measures_per_round, int verbose, ProgramMeasurer measurer, + Array<SearchCallback> pre_search_callbacks) = 0; + + /*! + * \brief Call SearchCallback with the current SearchPolicyNode + * \param callbacks SearchCallback to be called. + */ + void RunCallbacks(const Array<SearchCallback>& callbacks); + + static constexpr const char* _type_key = "ansor.SearchPolicy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); + + protected: + /*! + * \brief The set of already measured states. + * We store the string format for redundancy check. + */ + std::unordered_set<String> measured_states_set_; Review comment: For this internal `std::unordered_set`, I do not think we need to use tvm's String
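The `workload_registry.py` docstring quoted in this thread encodes a workload as the JSON string `[func_name, [args...]]` and rebuilds the compute DAG by calling `func_name(*args)`. Below is a minimal sketch of that round trip. It uses plain Python values in place of TVM tensors so it runs without TVM. `register`, `encode_workload`, `decode_workload`, and `matmul_shapes` are hypothetical names, not the module's actual API.

```python
import json

# Hypothetical registry in the spirit of WORKLOAD_FUNC_REGISTRY:
# maps a function name to the function that builds the workload.
WORKLOAD_FUNCS = {}


def register(func):
    """Register a workload-generating function under its own name."""
    WORKLOAD_FUNCS[func.__name__] = func
    return func


def encode_workload(func, *args):
    """Encode a workload as the JSON string [func_name, [args...]]."""
    return json.dumps([func.__name__, list(args)])


def decode_workload(key):
    """Decode the string and call the registered function to rebuild the workload."""
    func_name, args = json.loads(key)
    return WORKLOAD_FUNCS[func_name](*args)


@register
def matmul_shapes(n, m, k):
    # Stand-in for a compute declaration: returns the shapes of A, B and C
    # instead of TVM tensors.
    return [(n, k), (k, m), (n, m)]


key = encode_workload(matmul_shapes, 128, 64, 32)
print(key)  # ["matmul_shapes", [128, 64, 32]]
print(decode_workload(key))
```

JSON has no tuple type, so tuple arguments come back as lists after decoding. An encoder that accepts "tuple of int" arguments has to account for that.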

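The last quoted member, `measured_states_set_`, stores the string form of each measured state so the search policy can skip duplicates. Below is a minimal Python sketch of that redundancy check. It assumes each state can be printed to a canonical string. `MeasuredStateFilter`, `pick_new`, and the sample state strings are hypothetical, not part of the PR. The review comment makes the analogous point for C++: an internal set does not need TVM's `String`, and a standard string type serves the same purpose.

```python
class MeasuredStateFilter:
    """Redundancy check: remember the string form of every measured state."""

    def __init__(self):
        # Internal bookkeeping only, so built-in str keys are enough.
        self._measured = set()

    def pick_new(self, candidates):
        """Return candidates never seen before (in input order) and mark them measured.

        Duplicates inside the same batch are dropped too, because each string
        is added to the set as soon as it is accepted.
        """
        fresh = []
        for state_str in candidates:
            if state_str not in self._measured:
                self._measured.add(state_str)
                fresh.append(state_str)
        return fresh


f = MeasuredStateFilter()
print(f.pick_new(["split(i, 4)", "split(i, 8)", "split(i, 4)"]))
print(f.pick_new(["split(i, 8)", "split(i, 16)"]))
```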