jcf94 commented on a change in pull request #6107:
URL: https://github.com/apache/incubator-tvm/pull/6107#discussion_r458568998
##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -351,6 +351,68 @@ def compute_root(self, stage):
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))
+ def cache_read(self, stage, scope_name, reader_stages):
+ """ Schedule primitive corresponds to te.schedule.cache_read.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be cache read, which can be specified by the integer
index, Operation,
+ or output tensor of the stage.
+ scope_name : str
+ The scope name to be set for the new added read stage.
+ reader_stages : List[Union[int, Operation, Tensor]]
+ The reader stages. Each of the list can be specified by the
integer index, Operation,
+ or output tensor of the stage.
+
+ Returns
+ -------
+ new_stage_op : Operator
+ The Operator of the new added stage.
+
+ Notes
+ -----
+ Cache read step will add an extra stage to the original ComputeDAG.
+ """
+ if isinstance(reader_stages, list):
+ reader_stage_ids = [self._resolve_stage_id(id) for id in
reader_stages]
+ else:
+ raise ValueError("reader_stages must be a list of the integer
index, Operation, " + \
+ "or output tensor of the stage")
+
+ self.state_object, new_stage_id =
_ffi_api.StateCacheRead(self.state_object,
+
self._resolve_stage_id(stage),
+ scope_name,
reader_stage_ids,
+
self.compute_dag)
+ return self._insert_new_stage(int(new_stage_id))
+
+ def cache_write(self, stage, scope_name):
+ """ Schedule primitive corresponds to te.schedule.cache_write.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be cache write, which can be specified by the integer
index, Operation,
+ or output tensor of the stage.
+ scope_name : str
+ The scope name to be set for the new added write stage.
+
+ Returns
+ -------
+ new_stage_op : Operator
+ The Operator of the new added stage.
+
+ Notes
+ -----
+ Cache write step will add an extra stage to the original ComputeDAG, a
up-to-date
+ ComputeDAG is stored in State's `current_compute_dag`.
+ This step will cache write all output tensors of the target stage.
Review comment:
`current_compute_dag` is a class member that should only be used on the C++ side,
so I have removed it from the Python docstring.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org