yangjunpro commented on a change in pull request #5962:
URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449876496



##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    """
+    The Ansor computational graph and related program analyses.
+
+    We convert a compute declaration described by `tvm.compute` (could be a 
single operator or a
+    subgraph) to a ComputeDAG. It keeps the input/output tensors of the target 
compute declaration,
+    a list of all related operations in topo order as well as a set of 
analyses over each operation
+    stage (e.g. the total float operation count, consumer/producer relations 
of each operation
+    stage, whether a operation stage should be tiled/compute inlined ...). 
These analyses can
+    help the search policy to do some specific decisions during schedule 
search process.
+
+    ComputeDAG is also responsible for the interaction between Ansor LoopState 
and TVM schedule
+    (e.g. applying the LoopState transform steps to TVM schedule, providing 
LoopState with extra
+    information get from TVM schedule ...).
+
+    Parameters
+    ----------
+    compute : Union[List[Tensor], str]
+        `Tensor`s or workload key for a compute declaration.
+    """
+    def __init__(self, compute):
+        if isinstance(compute, str):
+            compute = workload_key_to_tensors(compute)
+        elif isinstance(compute, list):
+            for item in compute:
+                if not isinstance(item, tvm.te.Tensor):
+                    raise ValueError("The input of ComputeDAG should be a list 
of Tensor")
+        else:
+            raise ValueError("Invalid compute: " + compute +
+                             " . `ComputeDAG` expects a string or list of 
Tensor")
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+
+    def get_init_state(self):
+        """ Get the init state of this ComputeDAG.
+
+        Returns
+        -------
+        state : State
+            The initial State without any transform steps.
+        """
+        return State(self.init_state, self)
+
+    def apply_steps_from_state(self, state):
+        """
+        Apply the history transform steps of a State to TVM schedule.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+            A `te.schedule` and the target `te.Tensor`s to be used in 
`tvm.lower` or `tvm.build`
+        """
+        state_obj = state if isinstance(state, StateObject) else 
state.state_object
+        return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
+
+    def print_python_code_from_state(self, state):
+        """
+        Print transform steps in the history of a State as TVM's python 
schedule primitive.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+        str : Str
+            The Python schedule code.

Review comment:
       +1

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of 
tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness 
of the output.
+
+We implement these in python to utilize python's multiprocessing and error 
handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, 
state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, 
time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during 
program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during 
program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, 
verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.

Review comment:
       The unit is seconds; it is actually a wrapper of the 
_multiprocessing.Process.join()_ method.
   It would be better to make this more specific in the comments.

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,466 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Distributed measurement infrastructure to measure the runtime costs of tensor 
programs.
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness 
of the output.
+
+We implement these in python to utilize python's multiprocessing and error 
handling.
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# We use fork and a global variable to copy arguments between processings.
+# This can avoid expensive serialization of TVM IR when using 
multiprocessing.Pool
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The SearchTask of this measure.
+    state : State
+        The State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, 
state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, 
time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbose : int = 1
+            Verbosity level. 0 for silent, 1 to output information during 
program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbose : int = 1
+            Verbosity level. 0 for silent, 1 to output information during 
program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, 
verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.
+    n_parallel : int = multiprocessing.cpu_count()
+        Number of threads used to build in parallel.
+    build_func : str = 'default'
+        The name of registered build function.
+    """
+
+    def __init__(self,
+                 timeout=15,
+                 n_parallel=multiprocessing.cpu_count(),
+                 build_func='default'):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(ProgramRunner):
+    """ LocalRunner that uses local CPU/GPU to measures the time cost of 
programs.
+
+    Parameters
+    ----------
+    timeout : int = 10
+        The timeout limit for each run.
+    number : int = 3
+        The number of times to run the generated code for taking average.
+        We call these runs as one `repeat` of measurement.
+    repeat : int = 1
+        The number of times to repeat the measurement.
+        In total, the generated code will be run (1 + number x repeat) times,
+        where the first "1" is warm up and will be discarded.
+        The returned result contains `repeat` costs,
+        each of which is an average of `number` costs.
+    min_repeat_ms : int = 0
+        The minimum duration of one `repeat` in milliseconds.
+        By default, one `repeat` contains `number` runs. If this parameter is 
set,
+        the parameters `number` will be dynamically adjusted to meet the
+        minimum duration requirement of one `repeat`.
+        i.e., When the run time of one `repeat` falls below this time, the 
`number` parameter
+        will be automatically increased.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    """
+
+    def __init__(self,
+                 timeout=10,
+                 number=3,
+                 repeat=1,
+                 min_repeat_ms=0,
+                 cooldown_interval=0.0):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, 
cooldown_interval)
+
+
+class MeasureErrorNo(object):
+    """ Error type for MeasureResult. """
+    NO_ERROR = 0              # No error
+    INSTANTIATION_ERROR = 1   # Errors happen when apply transform steps from 
init state
+                              # Errors happen when compiling code on host 
(e.g. tvm.build)
+    COMPILE_HOST = 2
+    COMPILE_DEVICE = 3        # Errors happen when compiling code on device
+                              # (e.g. OpenCL JIT on the device)
+    RUNTIME_DEVICE = 4        # Errors happen when run program on device
+    WRONG_ANSWER = 5          # Answer is wrong when compared to a reference 
output
+    BUILD_TIMEOUT = 6         # Timeout during compilation
+    RUN_TIMEOUT = 7           # Timeout during run
+    UNKNOWN_ERROR = 8         # Unknown error
+
+
+def make_error_msg():
+    """ Get the error message from traceback. """
+    error_msg = str(traceback.format_exc())
+    if len(error_msg) > MAX_ERROR_MSG_LEN:
+        error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \

Review comment:
       What is the purpose of limiting the length of the error message?

##########
File path: python/tvm/ansor/record.py
##########
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Serialization and other I/O support for tuning logs (measurement records). 
"""
+
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import MeasureCallback, MeasureErrorNo
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):
+    """
+    A measurement callback that writes measurement records into a file.
+
+    Parameters
+    ----------
+    filename : str
+        File name for this callback to write log to.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+
+
+@tvm._ffi.register_object("ansor.LogReader")
+class LogReader(Object):
+    """
+    Reader of the json log file.
+
+    Parameters
+    ----------
+    filename : str = "ansor_tuning.json"
+        File name for this reader to load log from.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+
+    def read_lines(self, max_lines=None, skip_lines=None):

Review comment:
       Why not give _skip_lines_ a default value of _0_ rather than 
_None_? That way we can avoid the subsequent condition check.
   The same applies to _max_lines_.

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Common utilities for ansor"""
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+    import psutil
+except ImportError:
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+    """Get name of a function.
+
+    Parameters
+    ----------
+    func: Function
+        The target function.
+
+    Returns
+    -------
+    name: str
+        The function name.
+    """
+    return func.func_name if hasattr(func, 'func_name') else func.__name__
+
+
+def get_const_int(exp):
+    """Verifies expr is integer and get the constant value.
+
+    Parameters
+    ----------
+    exp : tvm.Expr or int
+        The input expression.
+
+    Returns
+    -------
+    out_value : int
+        The output.
+    """
+    if isinstance(exp, int):
+        return exp
+    if not isinstance(exp, (expr.IntImm)):
+        opt = Sequential([Simplify()])
+        exp = opt(exp)
+    if not isinstance(exp, (expr.IntImm)):
+        raise ValueError("Expect value to be constant int")
+    return exp.value
+
+
+def get_const_tuple(in_tuple):
+    """Verifies input tuple is IntImm, returns tuple of int.
+
+    Parameters
+    ----------
+    in_tuple : tuple of Expr
+        The input.
+
+    Returns
+    -------
+    out_tuple : tuple of int
+        The output.
+    """
+    return tuple(get_const_int(x) for x in in_tuple)
+
+
+
+def list_to_tuple(x):
+    """ Convert a list to a tuple recursively. """
+    assert isinstance(x, list)
+    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args):
+    """
+    Serialize arguments of a function to a hashable and jsonable tuple.
+    Currently this is mainly used for tvm.tensor.Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, Tensor):
+            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+        elif isinstance(t, list):
+            t = list_to_tuple(t)
+
+        assert isinstance(t, Hashable), str(t) + " is not hashable"
+        ret.append(t)
+
+    return tuple(ret)
+
+
+def deserialize_args(args):
+    """The inverse function of :code:`serialize_args`"""
+    ret = []
+    for t in args:
+        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+class NoDaemonProcess(multiprocessing.Process):
+    @property
+    def daemon(self):
+        return False
+
+    @daemon.setter
+    def daemon(self, value):
+        pass
+
+
+class NoDaemonContext(type(multiprocessing.get_context())):
+    Process = NoDaemonProcess
+
+
+class NoDaemonPool(multiprocessing.pool.Pool):
+    """A no daemon pool version of multiprocessing.Pool.

Review comment:
       Maybe what we mean here is a synchronous pool rather than an 
asynchronous one? 




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to