jcf94 commented on a change in pull request #5962: URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449406446
########## File path: python/tvm/ansor/loop_state.py ########## @@ -0,0 +1,211 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import + +""" +The definition of the "state" in search. A state consists of a current loop structure +and the transform history to reach its current loop structure. +To enable flexible manipulation of the loop structures, we implemented a lightweight loop +structure IR (Intermediate Representation) based on the original TVM IR but specifically +for schedule search. + +We build a new Sketch IR on top of TVM IR instead of using the existing TVM IR directly because: +1. We want fast incremental changes to the loop structures; +2. We want serializable transform history for replay, backtracking, and mutation; +3. We may create some macro schedule primitives that represent the combination of several +TVM schedule primitives. + +After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. +Because we share a lot of common objects during search, the transformation is implemented in +copy-on-write style. All objects are immutable, which is similar to TVM IR. +""" + +import tvm._ffi +from tvm.te.tensor import Operation, Tensor +from tvm.runtime import Object +from . 
import _ffi_api + + +@tvm._ffi.register_object("ansor.Iterator") +class Iterator(Object): + """ A loop iterator structure. """ + + +@tvm._ffi.register_object("ansor.Stage") +class Stage(Object): + """A stage in the compute declaration. Similar to tvm.te.schedule.Stage""" + + +@tvm._ffi.register_object("ansor.State") +class StateObject(Object): + """ The internal State object """ + def __eq__(self, other): + return _ffi_api.StateEqual(self, other) + + +class State: + """ + A state in the search process. It consists of the current loop structure + and the history steps to reach this state. + + Each State corresponds to a specific schedule for the target ComputeDAG. + + Parameters + ---------- + state_object : StateObject + The target StateObject, corresponding to the C++ internal State object. + dag : ComputeDAG + The original target ComputeDAG of this State. + + Notes + ----- + This is a wrapper class of StateObject to deal with copy-on-write property + """ + def __init__(self, state_object, dag): + self.state_object = state_object + self.compute_dag = dag + + self.stages_cache = None # A list to cache all stages Review comment: The wrapper implementation here aims to make the State Python APIs similar to the existing TVM primitive APIs. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
