junrushao1994 commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r797128964



##########
File path: include/tvm/tir/transform.h
##########
@@ -492,6 +492,112 @@ TVM_DLL Pass ConvertForLoopsToSerial();
  */
 TVM_DLL Pass UnifiedStaticMemoryPlanner();
 
+/*!
+ * \brief Transform annotated loops into pipelined one that ovarlaps producers and consumers.
+ *
+ * This pass detects loops with the software pipeline annotations and rewrite them to pipelined
+ * ones. The behavior of such rewriting depending on two annotations on the loop,
+ * attr::software_pipeline_stage, and attr::software_pipeline_order, which defines the stage and the
+ * order, respectively, of the components of the software pipeline. The components of the software
+ * pipeline is the direct children (ignoring BlockRealize / Block / SeqStmt) of the annotated loop.
+ * The value of the both annotations should be array of integers, with its size the same as the
+ * number of the components.
+ *
+ * The result of the rewriting is a block that has three blocks as its direct children which
+ * represents the prologue, the body, and the epilogue of the software pipeline. In the prologue,
+ * only components whose stage is less than max_stage will be executed. In the epilogue, only
+ * components whose stage is greater than 0 will be executed. In the body, all the components will
+ * be executed. Such rewriting enables behavior like prefetching, the components are not necessarily
+ * executed in the original order. attr::software_pipeline_order defines the order of the each
+ * component. Components belong to different stages can be reordered.
+ *
+ * Nested software pipelines are allowed. In this case, the inner software pipeline will be
+ * generated first. As a result, this may affect the number of components, i.e. the number of the
+ * direct children of the outer loop. In this case, the annotations for the outer software
+ * pipeline should include the result of the inner software pipeline, which is three blocks as
+ * discussed above.
+ *
+ * Buffer allocated inside the software pipeline may be resized to accommodate multiple versions
+ * of the original buffer. Block annotation attr::double_buffer_scope can be used to indicate that
+ * the block need to write in the double-buffering style.
+ *
+ * The following annotations are used to specify the behavior of this pass:
+ *     attr::software_pipeline_stage: Array of non-negative integers, each element should be in
+ *                                    range [0, max_stage], where max_stage is the maximum
+ *                                    (inclusive) stage.
+ *     attr::software_pipeline_order: Array of non-negative integers, should be a permutation of
+ *                                    [0, 1, ..., num_components - 1].
+ *     attr::double_buffer_scope: Integer index of the write regions of the block. Mark a buffer
+ *                                should be double-buffered during the software pipelining.
+ *
+ * Example:
+ *
+ * Before this pass, the TIR is:
+ *
+ * \code{.py}
+ * @T.prim_func
+ * def before_transform(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]) -> None:
+ *     for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
+ *         for i in T.serial(0, 16,
+ *                           annotations={"software_pipeline_stage": [0, 1],
+ *                                        "software_pipeline_order": [0, 1]}
+ *                          ):
+ *             with T.block():
+ *                 T.reads(A[tx, i])
+ *                 T.writes(C[tx, i])
+ *                 B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
+ *                 with T.block("B"):
+ *                     T.reads(A[tx, i])
+ *                     T.writes(B[tx, 0])
+ *                     B[tx, 0] = A[tx, i] * T.float32(2)
+ *                 with T.block("C"):
+ *                     T.reads(B[tx, 0])
+ *                     T.writes(C[tx, i])
+ *                     C[tx, i] = B[tx, 0] + T.float32(1)
+ * \endcode
+ *
+ * The TIR above annotate the loop as a two-stage pipeline, the components are not reordered.
+ * After this pass, the TIR is:

Review comment:
       ```suggestion
    * The TIR above annotates the loop as a two-stage pipeline with no reordering.
    * After applying this pass, the TIR is transformed into:
   ```
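
For readers following the stage rules in the doc comment above, here is a minimal sketch in plain Python of which components land in the prologue, body, and epilogue. It is not the pass implementation and uses no TVM API: `pipeline_phases` is a hypothetical helper, and it assumes that `order[i]` gives the position of component `i` within each phase.

```python
def pipeline_phases(stages, order):
    """Return (prologue, body, epilogue) as lists of component indices.

    Illustrative only: mirrors the rules stated in the doc comment, and
    assumes order[i] is the position of component i (an assumption).
    """
    if not stages:
        raise ValueError("at least one component is required")
    if len(stages) != len(order):
        raise ValueError("stage and order arrays must have the same size")
    if any(s < 0 for s in stages):
        raise ValueError("stages must be non-negative")
    if sorted(order) != list(range(len(order))):
        raise ValueError("order must be a permutation of 0..num_components-1")

    max_stage = max(stages)
    # Visit components in the order requested by the annotation.
    by_order = sorted(range(len(stages)), key=lambda i: order[i])

    prologue = [i for i in by_order if stages[i] < max_stage]  # stage < max_stage
    body = list(by_order)                                      # all components
    epilogue = [i for i in by_order if stages[i] > 0]          # stage > 0
    return prologue, body, epilogue


# Annotations from the example in the doc comment: two components, no reordering.
prologue, body, epilogue = pipeline_phases([0, 1], [0, 1])
print(prologue, body, epilogue)
```

With the annotations from the example, component 0 (block "B") is the only one in the prologue and component 1 (block "C") is the only one in the epilogue, while the body runs both.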



