merrymercy commented on a change in pull request #5962:
URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449745627



##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search.
+
+Each LoopState corresponds to a specific schedule for its target ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a history of 
transformations used to
+construct the loop structure.
+The loop structure keeps a preview of how the schedule will finally look 
after lowering the
+current state (e.g. number of iterators, the extent of each iterator, the 
compute_at locations ...).
+During the schedule search process, the loop structure can provide search 
policy with necessary
+information on how to perform further operations with the current state.
+The transform history is a sequence of TransformStep which will finally be 
mapped to schedule
+primitives. The steps can also be used for serialization of a state.
+
+The LoopState can be seen as a lightweight loop structure IR specifically for 
schedule search.
+We don't use the existing TVM IR but rather extend a new structure on it 
because:
+1. We want fast incremental change to the loop structures, search policy needs 
to get the immediate
+loop structures update rather than after TVM lowering;
+2. We want serializable transform history for replay, backtracking, and 
mutation;
+3. We may create some macro schedule primitives that represent the combination 
of several
+TVM schedule primitives.
+
+When the search is complete, we will lower the state to TVM IR with TVM's 
schedule primitives.
+Since we share a lot of common objects during search, the transformation is 
implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""

Review comment:
       Also, propagate the changes to c++ files.

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search.
+
+Each LoopState corresponds to a specific schedule for its target ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a history of 
transformations used to
+construct the loop structure.
+The loop structure keeps a preview of how the schedule will finally look 
after lowering the
+current state (e.g. number of iterators, the extent of each iterator, the 
compute_at locations ...).
+During the schedule search process, the loop structure can provide search 
policy with necessary
+information on how to perform further operations with the current state.
+The transform history is a sequence of TransformStep which will finally be 
mapped to schedule
+primitives. The steps can also be used for serialization of a state.
+
+The LoopState can be seen as a lightweight loop structure IR specifically for 
schedule search.
+We don't use the existing TVM IR but rather extend a new structure on it 
because:
+1. We want fast incremental change to the loop structures, search policy needs 
to get the immediate
+loop structures update rather than after TVM lowering;
+2. We want serializable transform history for replay, backtracking, and 
mutation;
+3. We may create some macro schedule primitives that represent the combination 
of several
+TVM schedule primitives.
+
+When the search is complete, we will lower the state to TVM IR with TVM's 
schedule primitives.
+Since we share a lot of common objects during search, the transformation is 
implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """ The internal State object """
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and a history of transformations used to construct it.
+
+    Each State corresponds to a specific schedule for its target ComputeDAG.
+
+    Parameters
+    ----------
+    state_object : StateObject
+        The target StateObject, corresponding to C++ internal State object.

Review comment:
       Remove all "target" before "StateObject", "State" and "ComputeDAG" in 
this file.

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search.
+
+Each LoopState corresponds to a specific schedule for its target ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a history of 
transformations used to
+construct the loop structure.
+The loop structure keeps a preview of how the schedule will finally look 
after lowering the
+current state (e.g. number of iterators, the extent of each iterator, the 
compute_at locations ...).
+During the schedule search process, the loop structure can provide search 
policy with necessary
+information on how to perform further operations with the current state.
+The transform history is a sequence of TransformStep which will finally be 
mapped to schedule
+primitives. The steps can also be used for serialization of a state.
+
+The LoopState can be seen as a lightweight loop structure IR specifically for 
schedule search.
+We don't use the existing TVM IR but rather extend a new structure on it 
because:
+1. We want fast incremental change to the loop structures, search policy needs 
to get the immediate
+loop structures update rather than after TVM lowering;
+2. We want serializable transform history for replay, backtracking, and 
mutation;
+3. We may create some macro schedule primitives that represent the combination 
of several
+TVM schedule primitives.
+
+When the search is complete, we will lower the state to TVM IR with TVM's 
schedule primitives.
+Since we share a lot of common objects during search, the transformation is 
implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """ The internal State object """
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and a history of transformations used to construct it.
+
+    Each State corresponds to a specific schedule for its target ComputeDAG.
+
+    Parameters
+    ----------
+    state_object : StateObject
+        The target StateObject, corresponding to C++ internal State object.
+    dag : ComputeDAG
+        The original target ComputeDAG of this State.
+
+    Notes
+    -----
+    This is a wrapper class of StateObject to deal with copy-on-write property
+    """
+    def __init__(self, state_object, dag):
+        self.state_object = state_object
+        self.compute_dag = dag
+
+        self.stages_cache = None  # A list to cache all stages
+        self.stage_id_map = {}    # A dict maps operation to stage id
+        self._update_stage_id_map()
+
+    @property
+    def stages(self):
+        """
+        Returns
+        -------
+        stages : List[Stage]
+        """
+        if not self.stages_cache:
+            self.stages_cache = self.state_object.stages
+        return self.stages_cache
+
+    @property
+    def stage_ops(self):
+        """
+        Returns
+        -------
+        ops: List[Operation]
+        """
+        if not self.stages_cache:
+            self.stages_cache = self.state_object.stages
+        return [stage.op for stage in self.stages_cache]
+
+    def reorder(self, stage, order):
+        """ Schedule primitive corresponds to te.reorder.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be reordered, can be a Stage order index, 
Stage operation or stage

Review comment:
       Remove all "target" before "stage". It is redundant.

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of 
tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness 
of the output.
+
+We implement these in python to utilize python's multiprocessing and error 
handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, 
state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, 
time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbose : int = 1
+            Verbosity level. 0 for silent, 1 to output information during 
program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be run.
+        verbose : int = 1
+            Verbosity level. 0 for silent, 1 to output information during 
program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, 
verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.
+    n_parallel : int = multiprocessing.cpu_count()
+        Number of threads used to build in parallel.
+    build_func : str = 'default'
+        The name of registered build function.
+    """
+
+    def __init__(self,
+                 timeout=15,
+                 n_parallel=multiprocessing.cpu_count(),
+                 build_func='default'):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(ProgramRunner):
+    """ LocalRunner that uses local CPU/GPU to measures the time cost of 
programs.
+
+    Parameters
+    ----------
+    timeout : int = 10
+        The timeout limit for each run.
+    number : int = 3
+        Number of measure times.
+    repeat : int = 1
+        Number of repeat times in each measure.
+    min_repeat_ms : int = 0
+        The minimum duration of one repeat in milliseconds.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.

Review comment:
       Also, add pointers to the docstring in c++ files.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to