zhiics commented on a change in pull request #5962: URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r448760203
########## File path: python/tvm/ansor/compute_dag.py ########## @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Computational graph and its analysis tools """ + +import hashlib + +import tvm._ffi +from tvm.runtime import Object +from tvm.te import PlaceholderOp, ComputeOp + +from .loop_state import State, StateObject +from .utils import get_const_tuple +from .workload_registry import workload_key_to_tensors + +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.ComputeDAG") +class ComputeDAG(Object): + """ + Computation declaration graph. + + Parameters + ---------- + compute : Union[List[Tensor], str] + `Tensor`s or workload key for a compute declaration. + """ + def __init__(self, compute): + if isinstance(compute, str): + compute = workload_key_to_tensors(compute) + elif isinstance(compute, list): + for item in compute: + if not isinstance(item, tvm.te.Tensor): + raise ValueError("The input of ComputeDAG should be a list of Tensor") + else: + raise ValueError("Invalid compute: " + compute + ". Expect a string or list of Tensor") + self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute) + + def get_init_state(self): + """ Get init state of this ComputeDAG. 
+ + Returns + ------- + state : State + The initial State without any transform steps. + """ + return State(_ffi_api.ComputeDAGGetInitState(self), self) + + def apply_steps_from_state(self, state): + """ + Apply transform steps according to the history of a State. + + Parameters + ---------- + state : Union[State, StateObject] + The target state to be applied to TVM schedule. + + Returns + ------- + A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build` + """ + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj) + + def print_python_code_from_state(self, state): Review comment: we should consider a better naming: pretty_print/print_state? ########## File path: python/tvm/ansor/auto_schedule.py ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +User interface for Ansor auto-scheduler. + +The basic schedule search process for Ansor is design to be: Review comment: designed ########## File path: python/tvm/ansor/auto_schedule.py ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +User interface for Ansor auto-scheduler. + +The basic schedule search process for Ansor is design to be: +`Program sampling` -> `Performance Tuning`. + +In `Program sampling`, we use some predefined or heuristic rules to generate several initial +schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model +and evolutionary search to seek for schedules with the best performance. Candidate schedules will +be measured in the target hardware. +""" + +import tvm._ffi +from tvm.runtime import Object +from .compute_dag import ComputeDAG +from .measure import LocalBuilder, LocalRunner +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.HardwareParams") +class HardwareParams(Object): + """ The parameters of target hardware, this is used to guide the search process of + SearchPolicy. + + TODO(...): This is considering to merge with the new Target: + https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844 + + Parameters + ---------- + num_cores : int + The number of device cores. + vector_unit_bytes : int + The width of vector units in bytes. + cache_line_bytes : int + The size of cache line in bytes. + max_unroll_vec : int + The max length of an axis to be unrolled or vectorized. 
+ max_innermost_split_factor : int + The max split factor for the innermost tile. + """ + def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor): + self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, + vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor) + + +@tvm._ffi.register_object("ansor.SearchTask") +class SearchTask(Object): + """ The meta-information of a search task. + + Parameters + ---------- + dag : ComputeDAG + The ComputeDAG for target compute declaration. + workload_key : str + The workload key for target compute declaration. + target : tvm.target.Target + The target device of this search task. + target_host : Optional[tvm.target.Target] + The target host device of this search task. + hardware_params : Optional[HardwareParams] + Hardware parameters used in this search task. + """ + def __init__(self, dag, workload_key, target, target_host=None, + hardware_params=None): + self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, + workload_key, target, target_host, + hardware_params) + + +@tvm._ffi.register_object("ansor.SearchPolicy") +class SearchPolicy(Object): + """ The base class for search policy """ + + +@tvm._ffi.register_object("ansor.EmptyPolicy") +class EmptyPolicy(SearchPolicy): + """ This is an example empty search policy which will always generate + the init state of target ComputeDAG. + """ + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) + + +@tvm._ffi.register_object("ansor.TuneOption") +class TuneOption(Object): + """ This controls the options of performance tuning. + + Parameters + ---------- + num_measure_trials: int = 0 + The number of total schedule measure trials. + Ansor takes `num_measure_trials` state for measuring in total, and finally gets the best + schedule among them. 
+ With `num_measure_trials` == 0, Ansor will do the schedule search but don't involve + measurement, this can be used if we want to quickly get a runnable schedule without + performance tuning. + early_stopping: int = -1 + Stops early the tuning if no improvement get after n measurements. + num_measures_per_round: int = 64 + The number of programs to be measured at each search round. + The whole schedule search process is designed to have several rounds to try a total + `num_measure_trials` schedules. + We have: `num_search_rounds` = `num_measure_trials` // `num_measures_per_round` + verbose: int = 1 + Verbosity level. 0 for silent, 1 to output information during schedule search. + builder: Union[Builder, str] = 'local' + Builder which builds the program. + runner: Union[Runner, str] = 'local' + Runner which runs the program and measures time costs. + measure_callbacks: Optional[List[MeasureCallback]] + Callback functions called after each measure. + Candidates: + - ansor.LogToFile + pre_search_callbacks: Optional[List[SearchCallback]] + Callback functions called before the search process. + Candidates: + - ansor.PreloadMeasuredStates + - ansor.PreloadCustomSketchRule + TODO(jcf94): Add these implementation in later PRs. + """ + def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_round=64, Review comment: early_stopping -> early_termination IMHO, this API looks a bit bulky to me, should we have some config dict to do this? ########## File path: python/tvm/ansor/workload_registry.py ########## @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Workload registration and serialization. + +We use a json string to represent a workload (a compute dag). +The format of the string is `[func_name, [args...]]`. +The dag should be the return value of this `func_name(*args)`. + +Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. +These strings are efficient for serialization/matching and wont' be too long. +When we need the dag, we decode the string and call the function, which will return the dag. +""" + +import pickle +import json + +import tvm._ffi +from .utils import serialize_args, deserialize_args + +WORKLOAD_FUNC_REGISTRY = {} + + +def register_workload_by_func(func): + """ Register a workload by generation function. + + The input function should take hashable and jsonable arguments + (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. 
+ + Examples + -------- + @ansor.register_workload_by_func + def matmul(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] + """ + func_name = func.__name__ + if func_name in WORKLOAD_FUNC_REGISTRY: + raise RuntimeError('%s has been registered already' % func_name) + WORKLOAD_FUNC_REGISTRY[func_name] = func + return func + + +def make_workload_key_by_func(func, args): Review comment: ditto ########## File path: python/tvm/ansor/workload_registry.py ########## @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Workload registration and serialization. + +We use a json string to represent a workload (a compute dag). +The format of the string is `[func_name, [args...]]`. +The dag should be the return value of this `func_name(*args)`. + +Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. +These strings are efficient for serialization/matching and wont' be too long. 
+When we need the dag, we decode the string and call the function, which will return the dag. +""" + +import pickle +import json + +import tvm._ffi +from .utils import serialize_args, deserialize_args + +WORKLOAD_FUNC_REGISTRY = {} + + +def register_workload_by_func(func): Review comment: Just `register_workload`? The argument is already a function, so the `_by_func` suffix is redundant; you can annotate `func` or assert that `func` is a function type. ########## File path: python/tvm/ansor/measure.py ########## @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Distributed measurement infrastructure to measure the runtime costs of tensor programs + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. + +We implement these in python to utilize python's multiprocessing and error handling +""" + +import os +import time +import shutil +import traceback +import tempfile +import multiprocessing + +import tvm._ffi +from tvm.runtime import Object, module, ndarray +from tvm.driver import build_module +from tvm.ir import transform +from tvm.contrib import tar, ndk + +from .
import _ffi_api +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout + +# The maximum length of error message +MAX_ERROR_MSG_LEN = 512 + +# Global variables used in build function +GLOBAL_BUILD_ARGUMENTS = None + +@tvm._ffi.register_object("ansor.MeasureCallback") +class MeasureCallback(Object): + """ Base class for measurement callback function. """ + + +@tvm._ffi.register_object("ansor.MeasureInput") +class MeasureInput(Object): + """ Store the input of a measurement. + + Parameters + ---------- + task : SearchTask + The target SearchTask. + state : State + The current State to be measured. + """ + def __init__(self, task, state): + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) + + +@tvm._ffi.register_object("ansor.BuildResult") +class BuildResult(Object): Review comment: `BuildResult` sounds a bit vague to me ########## File path: python/tvm/ansor/serialization.py ########## @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Serialization and other I/O support for tuning logs (measurement records)""" + +import numpy as np + +import tvm._ffi +from tvm.runtime import Object +from .measure import MeasureCallback, MeasureErrorNo +from . 
import _ffi_api + + +@tvm._ffi.register_object("ansor.LogToFile") +class LogToFile(MeasureCallback): + """ + A measurement callback that writes measurement records into a file. + + Parameters + ---------- + filename : str + File name for this callback to write log to. + """ + def __init__(self, filename="ansor_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename) + + +@tvm._ffi.register_object("ansor.LogReader") +class LogReader(Object): + """ + Reader of the json log file. + + Parameters + ---------- + filename : str = "ansor_tuning.json" + File name for this reader to load log from. + """ + def __init__(self, filename="ansor_tuning.json"): + self.__init_handle_by_constructor__(_ffi_api.LogReader, filename) + + def read_lines(self, max_lines=-1, skip_lines=0): + """ Read multiple lines from the log file. + + Parameters + ---------- + max_lines : int = -1 + The maximum number of lines. -1 means to read all lines. + skip_lines : int = 0 + Skip the first n lines. + + Returns + ------- + inputs : List[MeasureInput] + The MeasureInputs loaded from the log file. + results : List[MeasureResult] + The MeasureResults loaded from the log file. + """ + inputs, results = _ffi_api.LogReaderReadLines(self, max_lines, skip_lines) + return inputs, results + + def __iter__(self): + while True: + ret = _ffi_api.LogReaderReadNext(self) + if not ret: + break + yield ret[0], ret[1] # (input, result) + + +def load_from_file(filename): + """ + Load measurement records from a file. + + Parameters + ---------- + filename : str + File name to load log from. + + Returns + ------- + logs : List[MeasureInput, MeasureResult] + """ + return zip(*LogReader(filename).read_lines()) + + +def append_measure_records_to_file(filename, inputs, results): Review comment: This API looks too long, append_records/redirect_records/save_records? 
########## File path: python/tvm/ansor/compute_dag.py ########## @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Computational graph and its analysis tools """ + +import hashlib + +import tvm._ffi +from tvm.runtime import Object +from tvm.te import PlaceholderOp, ComputeOp + +from .loop_state import State, StateObject +from .utils import get_const_tuple +from .workload_registry import workload_key_to_tensors + +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.ComputeDAG") +class ComputeDAG(Object): + """ + Computation declaration graph. + + Parameters + ---------- + compute : Union[List[Tensor], str] + `Tensor`s or workload key for a compute declaration. + """ + def __init__(self, compute): + if isinstance(compute, str): + compute = workload_key_to_tensors(compute) + elif isinstance(compute, list): + for item in compute: + if not isinstance(item, tvm.te.Tensor): + raise ValueError("The input of ComputeDAG should be a list of Tensor") + else: + raise ValueError("Invalid compute: " + compute + ". Expect a string or list of Tensor") + self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute) + + def get_init_state(self): + """ Get init state of this ComputeDAG. 
+ + Returns + ------- + state : State + The initial State without any transform steps. + """ + return State(_ffi_api.ComputeDAGGetInitState(self), self) + + def apply_steps_from_state(self, state): Review comment: I think this API name is not very good as well ########## File path: python/tvm/ansor/workload_registry.py ########## @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Workload registration and serialization. + +We use a json string to represent a workload (a compute dag). +The format of the string is `[func_name, [args...]]`. +The dag should be the return value of this `func_name(*args)`. + +Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. +These strings are efficient for serialization/matching and wont' be too long. +When we need the dag, we decode the string and call the function, which will return the dag. +""" + +import pickle +import json + +import tvm._ffi +from .utils import serialize_args, deserialize_args + +WORKLOAD_FUNC_REGISTRY = {} + + +def register_workload_by_func(func): + """ Register a workload by generation function. 
+ + The input function should take hashable and jsonable arguments + (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. + + Examples + -------- + @ansor.register_workload_by_func + def matmul(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C') + return [A, B, C] + """ + func_name = func.__name__ + if func_name in WORKLOAD_FUNC_REGISTRY: + raise RuntimeError('%s has been registered already' % func_name) + WORKLOAD_FUNC_REGISTRY[func_name] = func + return func + + +def make_workload_key_by_func(func, args): + """ make a workload key from function and arguments. + + Parameters + ---------- + func : Function + The target function that returns the compute declaration Tensors. + args : Args + The args of the target function. + + Returns + ------- + workload_key : Str + The workload key of the target function. + """ + args = serialize_args(args) + + if callable(func): + func_name = func.__name__ + elif isinstance(func, str): + func_name = func + else: + raise ValueError("Invalid function: " + str(func)) + + if not func_name in WORKLOAD_FUNC_REGISTRY: + raise ValueError("%s is not registered. " % func, + "Please register it with @ansor.register_workload_by_func") + + return json.dumps((func_name,) + args) + + +def decode_workload_key_to_func_args(workload_key): + """ Decode a workload key to the registerd function name and its corresponding args. + + Parameters + ---------- + workload_key : str + The target workload key. + + Returns + ------- + name : str + The function name of this workload key. + args : List[Tensor] + The args of the generation function. + """ + workload = json.loads(workload_key) + if not workload[0] in WORKLOAD_FUNC_REGISTRY: + raise ValueError("%s is not registered. 
" % workload[0] + + "Please register it with @ansor.register_workload_by_func") + return workload[0], deserialize_args(workload[1:]) + + +@tvm._ffi.register_func("ansor.workload_key_to_tensors") +def workload_key_to_tensors(workload_key): + """ Get the input/output tensors from the workload key. + + This method is usually used to create a ComputeDAG by workload key. + + Parameters + ---------- + workload_key : str + The target workload key. + + Returns + ------- + tensors : List[Tensor] + The registered compute declaration Tensors. + """ + name, args = decode_workload_key_to_func_args(workload_key) + lookup = WORKLOAD_FUNC_REGISTRY[name] + assert callable(lookup) + return lookup(*args) + + +def dump_workload_func_registry(filename): Review comment: `save` would be a better name here, since there is already a corresponding `load`. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
