merrymercy commented on a change in pull request #6073:
URL: https://github.com/apache/incubator-tvm/pull/6073#discussion_r456228059
##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -161,16 +202,116 @@ def fuse(self, stage, iters):
The Stage to be fused, can be a Stage order index, Stage operation
or stage
output tensor.
iters : List[Iterator]
- The iterators to be fused
+ The iterators to be fused.
+
+ Returns
+ -------
+ res_it : Iterator
+ The fused Iterator.
+ """
+ self.state_object, res = _ffi_api.StateFuse(self.state_object,
+
self._resolve_stage_id(stage), iters)
+ return res
+
+ def vectorize(self, stage, iterator):
+ """ Schedule primitive corresponds to te.vectorize.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be vectorized, can be a Stage order index, Stage
operation or stage
+ output tensor.
+ iterator : Iterator
+ The iterator to be vectorized.
Returns
-------
res_it : Iterator
- The fused Iterator
+ The vectorized Iterator.
"""
- stage_id = self._resolve_stage_id(stage)
+ self.state_object, res = _ffi_api.StateVectorize(self.state_object,
+
self._resolve_stage_id(stage), iterator)
+ return res
+
+ def parallel(self, stage, iterator):
+ """ Schedule primitive corresponds to te.parallel.
- self.state_object, res = _ffi_api.StateFuse(self.state_object,
stage_id, iters)
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be paralleled, can be a Stage order index, Stage
operation or stage
+ output tensor.
+ iterator : Iterator
+ The iterator to be paralleled.
+
+ Returns
+ -------
+ res_it : Iterator
+ The paralleled Iterator.
+ """
+ self.state_object, res = _ffi_api.StateParallel(self.state_object,
+
self._resolve_stage_id(stage), iterator)
+ return res
+
+ def unroll(self, stage, iterator, max_unroll=None):
+ """ Schedule primitive corresponds to te.unroll.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be unrolled, can be a Stage order index, Stage
operation or stage
+ output tensor.
+ iterator : Iterator
+ The iterator to be unrolled.
+ max_unroll : Optional[int]
+ The max unroll limit. Iterator with extent larger than this limit
will be skipped.
+
+ Returns
+ -------
+ res_it : Iterator
+ The unrolled Iterator.
+ """
+ self.state_object, res = _ffi_api.StateUnroll(self.state_object,
+
self._resolve_stage_id(stage), iterator,
+ max_unroll if max_unroll
else -1)
+ return res
+
+ def bind(self, stage, iterator, thread_name):
+ """ Schedule primitive corresponds to te.bind.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be binded, can be a Stage order index, Stage
operation or stage
+ output tensor.
+ iterator : Iterator
+ The iterator to be binded.
+ thread_name : str
+ The thread type to be binded. Currently support:
+ - vthread
+ - blockIdx.x
+ - threadIdx.x
+ - blockIdx.y
+ - threadIdx.y
Review comment:
We should add them, because other policies might use them.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]