merrymercy commented on a change in pull request #6107:
URL: https://github.com/apache/incubator-tvm/pull/6107#discussion_r459812208



##########
File path: src/auto_scheduler/transform_step.h
##########
@@ -659,6 +671,153 @@ class ComputeRootStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Cache read step that corresponds to te::Schedule::cache_read.
+ * \note Cache read step will add an extra stage to the original ComputeDAG, an up-to-date ComputeDAG
+ * is stored in State's `current_compute_dag`.
+ */
+class CacheReadStepNode : public StepNode {
+ public:
+  /*! \brief The scope name to be set for the new added read stage. (e.g. local, shared, global) */
+  String scope_name;
+  /*! \brief The indexes of reader stages. */
+  Array<Integer> reader_stage_ids;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the new added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A mutable pointer to a `te::Stage` Array.
+   * \param stage_to_axes A mutable pointer to a StageToAxesMap.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensor of the new added stage.
+   */
+  te::Tensor ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* 
stage_to_axes,
+                             te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A mutable pointer to a `te::Stage` Array.
+   * \param stage_to_axes A mutable pointer to a StageToAxesMap.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* 
stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  static constexpr const char* record_prefix_str = "CHR";
+
+  static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to CacheReadStepNode.
+ * \sa CacheReadStepNode
+ */
+class CacheReadStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be cache read.
+   * \param scope_name The scope name to be set for the new added read stage.
+   * \param reader_stage_ids The indexes of reader stages.
+   */
+  CacheReadStep(int stage_id, String scope_name, const Array<Integer>& 
reader_stage_ids);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit CacheReadStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode);
+};
+
+/*!
+ * \brief Cache write step that corresponds to te::Schedule::cache_write.
+ * \note Cache write step will add an extra stage to the original ComputeDAG, an up-to-date
+ * ComputeDAG is stored in State's `current_compute_dag`.
+ * This step will cache write all output tensors of the target stage.
+ */
+class CacheWriteStepNode : public StepNode {
+ public:
+  /*!
+   * \brief The scope name to be set for the new added write stage. (e.g. local, shared,

Review comment:
       ditto




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to