jcf94 commented on a change in pull request #6142:
URL: https://github.com/apache/incubator-tvm/pull/6142#discussion_r461261906
##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -111,6 +115,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
ps->ApplyToState(state);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToState(state);
+ } else if (auto ps = step.as<FollowSplitStepNode>()) {
Review comment:
Yes, the problem is that these functions may have different parameters/return
values, so they can't be merged into a single virtual function
interface.
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
Review comment:
```suggestion
* \param state A mutable pointer to state, which will be updated.
* \return The iterator results after split.
```
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
Review comment:
```suggestion
* \param state A mutable pointer to state, which will be updated.
* \return The iterator results after split.
```
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
Review comment:
```suggestion
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param transform_steps An array record all transform steps.
* \return The iterator results after split.
```
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_id The index of the split step to follow in the history.
+ * \param n_split The number of split level.
+ */
+ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ * \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to split. */
+ int iter_id;
+ /*! \brief The indices of the split steps to follow in the history. */
+ Array<Integer> src_step_ids;
+ /*! \brief Use the length in this split level. */
+ int level;
+ /*! \brief If this is true, use factor. Otherwise, use nparts. */
+ bool factor_or_nparts;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split length.
+ * \param transform_steps An array record all transform steps.
+ * \return Split factor.
+ */
+ Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
Review comment:
```suggestion
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
```
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_id The index of the split step to follow in the history.
+ * \param n_split The number of split level.
+ */
+ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ * \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to split. */
+ int iter_id;
+ /*! \brief The indices of the split steps to follow in the history. */
+ Array<Integer> src_step_ids;
+ /*! \brief Use the length in this split level. */
+ int level;
+ /*! \brief If this is true, use factor. Otherwise, use nparts. */
+ bool factor_or_nparts;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split length.
+ * \param transform_steps An array record all transform steps.
+ * \return Split factor.
+ */
+ Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
Review comment:
```suggestion
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param transform_steps An array record all transform steps.
* \return The iterator results after split.
```
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
Review comment:
```suggestion
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
```
##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -487,6 +490,164 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
+/*! \brief Similar to SplitStepNode, but uses split factors from another step
+ * (i.e. Follow another split step) */
+class FollowSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to be split. */
+ int iter_id;
+ /*! \brief The index of the split step to follow in the history. */
+ int src_step_id;
+ /*! \brief The number of split level. */
+ int n_split;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split lengths.
+ * \param transform_steps An array record all transform steps.
+ * \param lengths The multiple split factors. Can be None to be filled by search policy.
+ */
+ void ExtractSplitLengths(const Array<Step>& transform_steps,
+ Array<Optional<Integer>>* lengths) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ */
+ Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \param transform_steps An array record all transform steps.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+ const Array<Step>& transform_steps) const;
+
+ static constexpr const char* record_prefix_str = "FSP";
+
+ static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be split.
+ * \param iter_id The index of the iterator to be split.
+ * \param src_step_id The index of the split step to follow in the history.
+ * \param n_split The number of split level.
+ */
+ FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FollowSplitStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+};
+
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ * \note This can be used for the split in cooperative fetching.
+ */
+class FollowFusedSplitStepNode : public StepNode {
+ public:
+ /*! \brief The id of the iter to split. */
+ int iter_id;
+ /*! \brief The indices of the split steps to follow in the history. */
+ Array<Integer> src_step_ids;
+ /*! \brief Use the length in this split level. */
+ int level;
+ /*! \brief If this is true, use factor. Otherwise, use nparts. */
+ bool factor_or_nparts;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Extract split length.
+ * \param transform_steps An array record all transform steps.
+ * \return Split factor.
+ */
+ Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
Review comment:
```suggestion
* \param state A mutable pointer to state, which will be updated.
* \return The iterator results after split.
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]