MasterJH5574 commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r795060478



##########
File path: src/tir/transforms/inject_software_pipeline.cc
##########
@@ -0,0 +1,785 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file inject_software_pipeline.cc
+ * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers
+ */
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../support/utils.h"
+#include "../schedule/utils.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace software_pipeline {
+
+/*!
+ * \brief Create a block and infer the access region with the given body.
+ *
+ * The result is an opaque block that doesn't contain any block iter vars. If the body is a
+ * block realize without a predicate, no new block is created; the block of the block realize
+ * is returned instead.
+ *
+ * \param body The body of the block.
+ * \param buffer_data_to_buffer The map from buffer data to buffer.
+ * \return The result block.
+ */
+Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) {
+  if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
+    if (is_one(block_realize->predicate)) {
+      // no need to create a new block
+      return block_realize->block;
+    }
+  }
+  Block block = Block({}, {}, {}, "", body);

Review comment:
       Would you like to annotate the arguments to make it clearer?
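
One way to annotate the positional arguments is with inline `/*name=*/` comments. The sketch below is only an illustration: `BlockSketch` and `MakeOpaqueBlock` are hypothetical stand-ins so the snippet compiles without TVM, and the parameter names are assumed to mirror the first five parameters of the `Block` constructor (`iter_vars`, `reads`, `writes`, `name_hint`, `body`).

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for tir::Block so this compiles without TVM.
// Assumption: the parameter order mirrors Block's first five parameters.
struct BlockSketch {
  std::vector<std::string> iter_vars;
  std::vector<std::string> reads;
  std::vector<std::string> writes;
  std::string name_hint;
  std::string body;

  BlockSketch(std::vector<std::string> iter_vars, std::vector<std::string> reads,
              std::vector<std::string> writes, std::string name_hint, std::string body)
      : iter_vars(std::move(iter_vars)),
        reads(std::move(reads)),
        writes(std::move(writes)),
        name_hint(std::move(name_hint)),
        body(std::move(body)) {}
};

// Hypothetical helper: builds an opaque block with every argument labeled.
BlockSketch MakeOpaqueBlock(const std::string& body) {
  return BlockSketch(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                     /*name_hint=*/"", /*body=*/body);
}
```

With the labels in place, a reader can tell which empty `{}` is which without looking up the constructor signature.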

##########
File path: include/tvm/tir/transform.h
##########
@@ -492,6 +492,112 @@ TVM_DLL Pass ConvertForLoopsToSerial();
  */
 TVM_DLL Pass UnifiedStaticMemoryPlanner();
 
+/*!
+ * \brief Transform annotated loops into pipelined ones that overlap producers and consumers.
+ *
+ * This pass detects loops with the software pipeline annotations and rewrites them to pipelined
+ * ones. The behavior of the rewriting depends on two annotations on the loop,
+ * attr::software_pipeline_stage and attr::software_pipeline_order, which define the stage and the
+ * order, respectively, of the components of the software pipeline. The components of the software
+ * pipeline are the direct children (ignoring BlockRealize / Block / SeqStmt) of the annotated loop.
+ * The value of each annotation should be an array of integers whose size equals the number of
+ * components.
+ *
+ * The result of the rewriting is a block that has three blocks as its direct children, which
+ * represent the prologue, the body, and the epilogue of the software pipeline. In the prologue,
+ * only components whose stage is less than max_stage will be executed. In the epilogue, only
+ * components whose stage is greater than 0 will be executed. In the body, all the components will
+ * be executed. Such rewriting enables behavior like prefetching, so the components are not
+ * necessarily executed in the original order. attr::software_pipeline_order defines the order of
+ * each component. Components that belong to different stages can be reordered.
+ *
+ * Nested software pipelines are allowed. In this case, the inner software pipeline will be
+ * generated first. As a result, this may affect the number of components, i.e. the number of the
+ * direct children of the outer loop. In this case, the annotations for the outer software
+ * pipeline should include the result of the inner software pipeline, which is three blocks as
+ * discussed above.
+ *
+ * Buffers allocated inside the software pipeline may be resized to accommodate multiple versions
+ * of the original buffer. Block annotation attr::double_buffer_scope can be used to indicate that
+ * the block needs to write in the double-buffering style.
+ *
+ * The following annotations are used to specify the behavior of this pass:
+ *     attr::software_pipeline_stage: Array of non-negative integers, each element should be in
+ *                                    range [0, max_stage], where max_stage is the maximum
+ *                                    (inclusive) stage.

Review comment:
       Thanks @vinx13, I love this comment! Though I'm still curious about the meanings of the `attr::software_pipeline_stage` values.
   
   In the example below, the stage array is `[0, 1]`. But what if the array is `[1, 0]`, `[0, 0]` or `[1, 1]`?
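
One way to explore the question is a small simulation of the rules in the doc comment. The sketch below is not the pass itself. It assumes that, at step `i` of the pipelined loop, a component with stage `s` runs original iteration `i - s`, which is one reading consistent with the prologue/body/epilogue description. `SimulatePipeline` and its trace format are hypothetical, and the actual pass may reject some of these stage arrays.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper, not part of TVM. Lists which components run at each
// step of the pipelined loop, written as c<component>(<original iteration>).
// Assumption: a component with stage s runs original iteration (step - s).
std::vector<std::string> SimulatePipeline(const std::vector<int>& stage,
                                          const std::vector<int>& order, int extent) {
  assert(stage.size() == order.size() && !stage.empty());
  const int num = static_cast<int>(stage.size());
  const int max_stage = *std::max_element(stage.begin(), stage.end());
  assert(extent > max_stage);  // keeps the three phases well defined

  // by_order[k] is the component that runs k-th within one step.
  std::vector<int> by_order(num);
  for (int c = 0; c < num; ++c) by_order[order[c]] = c;

  std::vector<std::string> trace;
  for (int step = 0; step < extent + max_stage; ++step) {
    std::string line =
        step < max_stage ? "prologue:" : (step < extent ? "body:" : "epilogue:");
    for (int c : by_order) {
      const int iter = step - stage[c];
      // Skip components whose original iteration falls outside the loop.
      if (iter >= 0 && iter < extent) {
        line += " c" + std::to_string(c) + "(" + std::to_string(iter) + ")";
      }
    }
    trace.push_back(line);
  }
  return trace;
}
```

Under this reading, `SimulatePipeline({0, 1}, {0, 1}, 3)` yields `prologue: c0(0)`, `body: c0(1) c1(0)`, `body: c0(2) c1(1)`, `epilogue: c1(2)`. With `{0, 0}` there is no prologue or epilogue and the loop is unchanged. With `{1, 1}` the prologue is empty and both components are shifted by one step. With `{1, 0}` the second component runs an iteration before the first, which would only be valid if it does not depend on the first.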



