junrushao1994 commented on a change in pull request #5962: URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449831169
########## File path: src/ansor/search_policy/search_policy.cc ########## @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/search_policy/search_policy.cc + * \brief The base class of search policies. + */ + +#include "search_policy.h" + +#include <tvm/runtime/registry.h> + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); +TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); + +void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) { + if (callbacks.defined() && callbacks.size()) { + for (const auto& callback : callbacks) { + callback->Callback(this); + } + } +} Review comment: Use `tvm::runtime::Optional` to indicate a nullable value. ```suggestion void SearchPolicyNode::RunCallbacks(const Optional<Array<SearchCallback>>& callbacks) { if (callbacks.defined()) { for (const auto& callback : callbacks.value()) { callback->Callback(this); } } } ``` ########## File path: src/ansor/record.cc ########## @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/record.cc + * \brief Json serialization format for dumping and loading tuning records. + */ + +#include "record.h" + +#include <dmlc/json.h> +#include <tvm/runtime/registry.h> + +#include <fstream> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "loop_state.h" +#include "transform_step.h" +#include "utils.h" + +// Json serialization handler for MeasureInput, MeasureResult +// (and recursively for SearchTask, State, Step, ...) +namespace dmlc { +namespace json { + +inline std::vector<int>& IntArrayToVector(std::vector<int>* out, + const ::tvm::Array<::tvm::Integer>& data) { + out->clear(); + for (const auto& x : data) { + CHECK(x.defined()); + out->push_back(x); + } + return *out; +} Review comment: I think it is okay to define `out` inside this function and return it; there is no need to pass `out` in. Copy elision can handle it properly. ########## File path: src/ansor/loop_state.cc ########## @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.cc + * \brief An lightweight IR (intermediate representation) for loop structures. + * see ansor/loop_state.h for more explanation. + */ + +#include "loop_state.h" + +#include <tvm/runtime/registry.h> +#include <tvm/te/operation.h> + +#include <utility> + +#include "transform_step.h" +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); +TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +/********** Iterator **********/ +Iterator::Iterator(String name, Range range, IteratorType iter_type, + IteratorAnnotation annotation) { + auto node = make_object<IteratorNode>(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + data_ = std::move(node); +} + +/********** Stage **********/ +Stage::Stage(te::Operation op) { + auto node = make_object<StageNode>(); + if (op->IsInstance<te::ComputeOpNode>()) { + node->op_type = kCompute; + auto* pop = op.as<te::ComputeOpNode>(); + for (const auto& axis : pop->axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone)); + } + } else if 
(op->IsInstance<te::PlaceholderOpNode>()) { + node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; + node->op = std::move(op); + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters, + ComputeAtType compute_at, StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at, + StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +/********** State **********/ +State::State(const Array<te::Operation>& ops) { + auto node = make_object<StateNode>(); + for (const auto& op : ops) { + node->stages.push_back(Stage(op)); + } + node->complete = true; + data_ = std::move(node); +} + +/********** Schedule primitives apis for state **********/ +void State::reorder(int stage_id, const Array<Iterator>& order) { + const Stage& stage = operator->()->stages[stage_id]; + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array<Integer> after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +Array<Iterator> State::split(int stage_id, const Iterator& it, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + 
it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +Iterator State::fuse(int stage_id, const Array<Iterator>& iters) { + const Stage& stage = operator->()->stages[stage_id]; + Array<Integer> indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStep(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +/********** Step implementations for state **********/ +void State::DoReorderStep(const ReorderStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Array<Iterator> iters; + for (auto x : step->after_ids) { + iters.push_back(stage->iters[x]); + } + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), + stage->compute_at, stage->attrs)); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + + PrimExpr tosplit_min, tosplit_extent; Review comment: ```suggestion Optional<PrimExpr> tosplit_min, tosplit_extent; ``` ########## File path: src/ansor/loop_state.cc ########## @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.cc + * \brief An lightweight IR (intermediate representation) for loop structures. + * see ansor/loop_state.h for more explanation. + */ + +#include "loop_state.h" + +#include <tvm/runtime/registry.h> +#include <tvm/te/operation.h> + +#include <utility> + +#include "transform_step.h" +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); +TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +/********** Iterator **********/ +Iterator::Iterator(String name, Range range, IteratorType iter_type, + IteratorAnnotation annotation) { + auto node = make_object<IteratorNode>(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + data_ = std::move(node); +} + +/********** Stage **********/ +Stage::Stage(te::Operation op) { + auto node = make_object<StageNode>(); + if (op->IsInstance<te::ComputeOpNode>()) { + node->op_type = kCompute; + auto* pop = op.as<te::ComputeOpNode>(); + for (const auto& axis : pop->axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone)); + } + } else if (op->IsInstance<te::PlaceholderOpNode>()) { + node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; 
+ node->op = std::move(op); + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters, + ComputeAtType compute_at, StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at, + StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +/********** State **********/ +State::State(const Array<te::Operation>& ops) { + auto node = make_object<StateNode>(); + for (const auto& op : ops) { + node->stages.push_back(Stage(op)); + } + node->complete = true; + data_ = std::move(node); +} + +/********** Schedule primitives apis for state **********/ +void State::reorder(int stage_id, const Array<Iterator>& order) { + const Stage& stage = operator->()->stages[stage_id]; + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array<Integer> after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +Array<Iterator> State::split(int stage_id, const Iterator& it, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +Iterator State::fuse(int stage_id, const Array<Iterator>& iters) { + const Stage& stage = operator->()->stages[stage_id]; + Array<Integer> indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStep(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +/********** Step implementations for state **********/ +void State::DoReorderStep(const ReorderStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Array<Iterator> iters; + for (auto x : step->after_ids) { + iters.push_back(stage->iters[x]); + } + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), + stage->compute_at, stage->attrs)); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + + PrimExpr tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = tosplit_extent = PrimExpr(); + } + + Array<Iterator> outs; + for (size_t i = 0; i < lengths.size(); ++i) { + PrimExpr l; + String name; + if (inner_to_outer) { + l = lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." 
+ std::to_string(i); + } + Iterator res; + if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { + res = Iterator(name, Range::FromMinExtent(tosplit_min, l), it->iter_type, kNone); + tosplit_min = 0; + tosplit_extent = indexdiv(tosplit_extent + l - 1, l); + } else { + res = Iterator(name, Range(), it->iter_type, kNone); + tosplit_min = tosplit_extent = PrimExpr(); + } + outs.push_back(std::move(res)); + } + + Range range; + if (tosplit_min.defined() && tosplit_extent.defined()) { + range = Range::FromMinExtent(tosplit_min, tosplit_extent); + } + if (inner_to_outer) { + outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone)); + // Reverse the Iterator array + Array<Iterator> temp(outs.rbegin(), outs.rend()); + outs = std::move(temp); + } else { + outs.push_back( + Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_type, kNone)); + } + + Array<Iterator> new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), outs.begin(), outs.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs)); + + return outs; +} + +Array<Iterator> State::DoSplitStep(const SplitStep& step) { + return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer); +} + +Iterator State::DoFuseStep(const FuseStep& step) { + int stage_id = step->stage_id; + const Stage& stage = operator->()->stages[stage_id]; + + String new_name; + PrimExpr new_extent = 1; + IteratorType new_iter_type = kSpecial; + + for (size_t i = 0; i < step->fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1); + } + + const Iterator& it = stage->iters[step->fused_ids[i]]; + new_name = new_name + it->name + 
"@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_type = it->iter_type; + } else { + if (new_iter_type != it->iter_type) { + new_iter_type = kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::FromMinExtent(0, new_extent); + } + Iterator new_it = Iterator(new_name, range, new_iter_type, kNone); + Array<Iterator> new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + step->fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs)); + + return new_it; +} + +void State::DoSteps(const ComputeDAG& dag) { Review comment: Is `dag` unused? ########## File path: src/ansor/search_policy/search_policy.cc ########## @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file ansor/search_policy/search_policy.cc + * \brief The base class of search policies. + */ + +#include "search_policy.h" + +#include <tvm/runtime/registry.h> + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); +TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); Review comment: Same here ########## File path: src/ansor/loop_state.h ########## @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.h + * \brief The definition of the "state" in search. + * + * Each LoopState corresponds to a schedule for its ComputeDAG. + * A LoopState consists of: 1. a current loop structure; 2. a list of transformation steps used to + * construct the loop structure. + * The loop structure keeps a preview of how the schedule will finally look like after lowering the + * current state (e.g. number of iterators, the extent of each iterator, the compute_at locations + * ...). + * During the schedule search process, the loop structure can provide search policy with necessary + * information on how to manipulate the current state. + * The transform history is a sequence of `TransformStep` which will finally be mapped to TVM + * schedule primitives. 
The steps can also be used for the serialization of a state. + * + * The LoopState can be seen as a lightweight loop structure IR specifically for schedule search. + * We don't use the existing TVM IR but to extend a new structure on it is because: + * 1. We want fast incremental change to the loop structures. The search policy needs to get the + * immediate loop structures update rather than after TVM lowering; + * 2. We want serializable transform history for replay, backtracking, and mutation; + * 3. We may create some macro schedule primitives that represent the combination of several + * TVM schedule primitives. + * + * When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives. + * Since we share a lot of common objects during search, the transformation is implemented in + * copy on write style. All objects are immutable, which is similar to TVM IR. + */ + +#ifndef TVM_ANSOR_LOOP_STATE_H_ +#define TVM_ANSOR_LOOP_STATE_H_ + +#include <tvm/runtime/container.h> + +#include <functional> + +#include "transform_step.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +class ComputeDAG; + +/*! \brief The type of a stage. */ +enum StageType { + /*! \brief A placeholder stage. */ + kPlaceholder = 0, + /*! \brief A compute stage. */ + kCompute = 1 +}; + +/*! \brief The type of compute location. */ +enum ComputeAtType { + /*! \brief Compute at root. */ + kRoot = 0, + /*! \brief Compute inlined. */ + kInlined = 1, + /*! \brief Compute at some iterator. */ + kIter = 2, +}; + +/*! \brief The type of an iterator. */ +enum IteratorType { + /*! \brief Spatial iterator. */ + kSpace = 0, + /*! \brief Reduction iterator. */ + kReduce = 1, + /*! \brief Fused spatial and reduction iterator. */ + kMixed = 2, + /*! \brief Special iterator. (e.g. virtual root iterator) */ + kSpecial = 3 +}; + +/*! \brief The type of an iterator's annotation. */ +enum IteratorAnnotation { + /*! \brief This iterator has no annotation. 
*/ + kNone = 0, + /*! \brief This iterator has been unrolled. */ + kUnroll = 1, + /*! \brief This iterator has been vectorized. */ + kVectorize = 2, + /*! \brief This iterator has been paralleld. */ + kParallel = 3, + /*! \brief This iterator has been bind to vthread. */ + kVThread = 4, + /*! \brief This iterator has been bind to blockIdx.x. */ + kBlockX = 5, + /*! \brief This iterator has been bind to threadIdx.x. */ + kThreadX = 6, + /*! \brief This iterator has been bind to blockIdx.y. */ + kBlockY = 7, + /*! \brief This iterator has been bind to threadIdx.y. */ + kThreadY = 8, + /*! \brief This iterator has been mapped with a tensorize intrinsic. */ + kTensorized = 9 +}; + +/*! + * \brief A for loop iterator + * Similar to tvm::IterVar in `include/tvm/tir/expr.h` + */ +class IteratorNode : public Object { + public: + /*! \brief The name of this iterator. */ + String name; + /*! \brief The range of this iterator. */ + Range range; + /*! \brief The iterator type of this iterator. */ + IteratorType iter_type; + /*! \brief The annotation type of this iterator. */ + IteratorAnnotation annotation; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + } + + static constexpr const char* _type_key = "ansor.Iterator"; + TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); +}; + +/*! + * \brief Managed reference to IteratorNode. + * \sa IteratorNode + */ +class Iterator : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param name The name of this iterator. + * \param range The range of this iterator. + * \param iter_type The iterator type of this iterator. + * \param annotation The annotation type of this iterator. + */ + Iterator(String name, Range range, IteratorType iter_type, IteratorAnnotation annotation); + + TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); +}; + +/*! \brief Stage-level attributes. */ +struct StageAttributes { + /*! 
\brief The maximum steps for the pragma `auto_unroll_max_step`. */ + int auto_unroll_max_step; + /*! \brief The storage offset for the schedule primitive `storage_align`. */ + int storage_offset; +}; + +/*! + * \brief A op stage in the compute declaration. + * Similar to te::Stage in `include/schedule.h`. + */ +class StageNode : public Object { + public: + /*! \brief The operator of this stage */ + te::Operation op; + /*! \brief The type of this stage. */ + StageType op_type; + /*! \brief The iterators in this stage. */ + Array<Iterator> iters; + /*! \brief The compute location of this stage. */ + ComputeAtType compute_at; + /*! \brief Other stage-level attributes. */ + StageAttributes attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("iters", &iters); + } + + static constexpr const char* _type_key = "ansor.Stage"; + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); +}; + +/*! + * \brief Managed reference to StageNode. + * \sa StageNode + */ +class Stage : public ObjectRef { + public: + /*! + * \brief The constructor. + * \param op A `te::Operation`. + */ + explicit Stage(te::Operation op); + /*! + * \brief The constructor. + * \param op A `te::Operation`. + * \param op_type The stage type of this op. + * \param iters The iterators of this op. (copy) + * \param compute_at The compute at type of this op. + * \param attrs Other stage-level attributes. + */ + Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters, ComputeAtType compute_at, + StageAttributes attrs); + /*! + * \brief The constructor. + * \param op A `te::Operation`. + * \param op_type The stage type of this op. + * \param iters The iterators of this op. (move) + * \param compute_at The compute at type of this op. + * \param attrs Other stage-level attributes. 
+ */ + Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at, Review comment: Move semantics are not really useful in this case: `Array<>` is a simple ref-counted pointer. ########## File path: src/ansor/loop_state.cc ########## @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.cc + * \brief An lightweight IR (intermediate representation) for loop structures. + * see ansor/loop_state.h for more explanation.
+ */ + +#include "loop_state.h" + +#include <tvm/runtime/registry.h> +#include <tvm/te/operation.h> + +#include <utility> + +#include "transform_step.h" +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); +TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +/********** Iterator **********/ +Iterator::Iterator(String name, Range range, IteratorType iter_type, + IteratorAnnotation annotation) { + auto node = make_object<IteratorNode>(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + data_ = std::move(node); +} + +/********** Stage **********/ +Stage::Stage(te::Operation op) { + auto node = make_object<StageNode>(); + if (op->IsInstance<te::ComputeOpNode>()) { + node->op_type = kCompute; + auto* pop = op.as<te::ComputeOpNode>(); + for (const auto& axis : pop->axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone)); + } + } else if (op->IsInstance<te::PlaceholderOpNode>()) { + node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; + node->op = std::move(op); + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters, + ComputeAtType compute_at, StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at, + StageAttributes 
attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +/********** State **********/ +State::State(const Array<te::Operation>& ops) { + auto node = make_object<StateNode>(); + for (const auto& op : ops) { + node->stages.push_back(Stage(op)); + } + node->complete = true; + data_ = std::move(node); +} + +/********** Schedule primitives apis for state **********/ +void State::reorder(int stage_id, const Array<Iterator>& order) { + const Stage& stage = operator->()->stages[stage_id]; + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array<Integer> after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +Array<Iterator> State::split(int stage_id, const Iterator& it, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +Iterator State::fuse(int stage_id, const Array<Iterator>& iters) { + const Stage& stage = operator->()->stages[stage_id]; + Array<Integer> indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStep(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +/********** Step implementations for state **********/ +void State::DoReorderStep(const ReorderStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Array<Iterator> iters; + for (auto x : step->after_ids) { + iters.push_back(stage->iters[x]); + } + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), + stage->compute_at, stage->attrs)); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + + PrimExpr tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = tosplit_extent = PrimExpr(); Review comment: ```suggestion tosplit_min = NullOpt; tosplit_extent = NullOpt; ``` ########## File path: src/ansor/loop_state.cc ########## @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ansor/loop_state.cc + * \brief An lightweight IR (intermediate representation) for loop structures. + * see ansor/loop_state.h for more explanation. + */ + +#include "loop_state.h" + +#include <tvm/runtime/registry.h> +#include <tvm/te/operation.h> + +#include <utility> + +#include "transform_step.h" +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); +TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +/********** Iterator **********/ +Iterator::Iterator(String name, Range range, IteratorType iter_type, + IteratorAnnotation annotation) { + auto node = make_object<IteratorNode>(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + data_ = std::move(node); +} + +/********** Stage **********/ +Stage::Stage(te::Operation op) { + auto node = make_object<StageNode>(); + if (op->IsInstance<te::ComputeOpNode>()) { + node->op_type = kCompute; + auto* pop = op.as<te::ComputeOpNode>(); + for (const auto& axis : pop->axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone)); + } + } else if (op->IsInstance<te::PlaceholderOpNode>()) { + node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; 
+ node->op = std::move(op); + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters, + ComputeAtType compute_at, StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +Stage::Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at, + StageAttributes attrs) { + auto node = make_object<StageNode>(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->attrs = attrs; + data_ = std::move(node); +} + +/********** State **********/ +State::State(const Array<te::Operation>& ops) { + auto node = make_object<StateNode>(); + for (const auto& op : ops) { + node->stages.push_back(Stage(op)); + } + node->complete = true; + data_ = std::move(node); +} + +/********** Schedule primitives apis for state **********/ +void State::reorder(int stage_id, const Array<Iterator>& order) { + const Stage& stage = operator->()->stages[stage_id]; + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + << "should be specified"; + Array<Integer> after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +Array<Iterator> State::split(int stage_id, const Iterator& it, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + SplitStep step = + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), lengths, inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +Iterator State::fuse(int stage_id, const Array<Iterator>& iters) { + const Stage& stage = operator->()->stages[stage_id]; + Array<Integer> indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStep(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +/********** Step implementations for state **********/ +void State::DoReorderStep(const ReorderStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Array<Iterator> iters; + for (auto x : step->after_ids) { + iters.push_back(stage->iters[x]); + } + StateNode* pstate = CopyOnWrite(); + pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters), + stage->compute_at, stage->attrs)); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + + PrimExpr tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = tosplit_extent = PrimExpr(); + } + + Array<Iterator> outs; + for (size_t i = 0; i < lengths.size(); ++i) { + PrimExpr l; + String name; + if (inner_to_outer) { + l = lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." + std::to_string(i); + } + Iterator res; + if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { Review comment: Does it mean some integers in `lengths` are None? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
