junrushao1994 commented on a change in pull request #10066:
URL: https://github.com/apache/tvm/pull/10066#discussion_r800253537



##########
File path: src/tir/transforms/inject_software_pipeline.cc
##########
@@ -0,0 +1,785 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
/*!
 * \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers
 */
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/transform.h>
+
+#include "../../support/utils.h"
+#include "../schedule/utils.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace software_pipeline {
+
+/*!
+ * \brief Create a block and infer the access region with the given body.
+ *
+ * The result is a opaque block that doesn't contain any block iter vars. In 
case the body is a
+ * block realize without predicate, it is unnecessary to create a new block, 
the block of the block
+ * realize will be returned.
+ *
+ * \param body The body of the block.
+ * \param buffer_data_to_buffer The map from buffer data to buffer.
+ * \return The result block.
+ */
+Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& 
buffer_data_to_buffer) {
+  if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
+    if (is_one(block_realize->predicate)) {
+      // no need to create a new block
+      return block_realize->block;
+    }
+  }
+  Block block = Block({}, {}, {}, "", body);
+  auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
+  auto* n = block.CopyOnWrite();
+  n->reads = access[0];
+  n->writes = access[1];
+  return block;
+}
+
/*! Structure that represents the stage and order of a software pipeline component. */
struct PipelineStageOrder {
  int stage;  // pipeline stage the component belongs to
  int order;  // position of the component within the pipeline
  explicit PipelineStageOrder(int stage_, int order_) : stage(stage_), order(order_) {}
};
+
/*! \brief Map from each block of the pipeline to its stage/order assignment. */
using PipelineInfo = std::unordered_map<Block, PipelineStageOrder, ObjectPtrHash, ObjectPtrEqual>;
+
/*! \brief Def/use stage information of a buffer in the software pipeline. */
struct BufferAccessInfo {
  int def;  // the defining stage of the buffer (initialized to -1)
  int use;  // the last using stage of the buffer (initialized to -1)
  explicit BufferAccessInfo(int def_stage = -1, int use_stage = -1)
      : def(def_stage), use(use_stage) {}
};
+
+/*!
+ * \brief Rewriter for the body of the software pipeline. This pass inserts 
`floormod` to indices
+ * of the remapped buffer to select the version corresponding to the pipeline 
stage.
+ */
+class PipelineBodyRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Constructor of PipelineBodyRewriter.
+   * \param buffer_data_to_buffer The map from buffer data to buffer.
+   * \param buffer_remap The map from original buffer to the buffer with 
updated shape for
+   *        multi-versioning in the sofeware pipeline.
+   * \param pipeline_loop The original loop to be software pipelined.
+   * \param access_all_versions Whether all versions the the buffers in the 
software pipeline are
+   *        accessed. This will be used to update block access region. In the 
prologue and epilogue
+   *        of a two-stage software pipeline, only one version of these 
buffers are accessed.
+   * \param fragment_info Information about tensor core fragment
+   */
+  PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer,
+                       const Map<Buffer, Buffer>& buffer_remap, For 
pipeline_loop,
+                       bool access_all_versions,
+                       const std::unordered_map<const VarNode*, FragmentInfo>& 
fragment_info)
+      : buffer_data_to_buffer_(buffer_data_to_buffer),
+        buffer_remap_(buffer_remap),
+        pipeline_loop_(pipeline_loop),
+        access_all_versions_(access_all_versions),
+        fragment_info_(fragment_info) {}
+
+ private:
+  BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) 
const {
+    auto it = buffer_remap_.find(buffer_region->buffer);
+    if (it != buffer_remap_.end()) {
+      Region new_region = buffer_region->region;
+      const Buffer& new_buffer = (*it).second;
+      // For pipeline buffers, relax the access region of the first dimension 
to full extent
+      // if access_all_versions == true
+      Range accessed_version =
+          access_all_versions_
+              ? Range::FromMinExtent(0, new_buffer->shape[0])
+              : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - 
pipeline_loop_->min),
+                                              new_buffer->shape[0]),
+                                     Integer(1));
+      new_region.insert(new_region.begin(), accessed_version);
+      return BufferRegion(new_buffer, new_region);
+    }
+    return buffer_region;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& alloc_buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
+    }
+    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    BlockNode* n = block.CopyOnWrite();
+    n->reads.MutateByApply(
+        std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, 
std::placeholders::_1));
+    n->writes.MutateByApply(
+        std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, 
std::placeholders::_1));
+    for (const Buffer& alloc_buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.erase(alloc_buffer->data);
+    }
+    return std::move(block);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_remap_.find(store->buffer);
+    if (it == buffer_remap_.end()) {
+      return std::move(store);
+    }
+    const Buffer& new_buffer = (*it).second;
+    auto* n = store.CopyOnWrite();
+    n->buffer = new_buffer;
+    PrimExpr version =
+        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), 
new_buffer->shape[0]);
+    n->indices.insert(n->indices.begin(), version);
+    return std::move(store);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto it = buffer_remap_.find(load->buffer);
+    if (it == buffer_remap_.end()) {
+      return std::move(load);
+    }
+    const Buffer& new_buffer = (*it).second;
+    auto* n = load.CopyOnWrite();
+    n->buffer = new_buffer;
+    PrimExpr version =
+        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), 
new_buffer->shape[0]);
+    n->indices.insert(n->indices.begin(), version);
+    return std::move(load);
+  }
+
+  int GetWmmaFragmentSize(const Buffer& buffer) {
+    auto it = fragment_info_.find(buffer->data.get());
+    ICHECK(it != fragment_info_.end());
+    const FragmentInfo& info = (*it).second;
+    String scope = buffer.scope();
+    if (scope == "wmma.matrix_a") {
+      return info.m * info.k;
+    } else if (scope == "wmma.matrix_b") {
+      return info.n * info.k;
+    } else if (scope == "wmma.accumulator") {
+      return info.m * info.n;
+    } else {
+      ICHECK(0);
+      throw;
+    }
+  }
+
+  PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& 
new_buffer,
+                                    const PrimExpr& old_index) {
+    PrimExpr new_buffer_offset = old_index;
+
+    int fragment_size = GetWmmaFragmentSize(old_buffer);
+    PrimExpr offset =
+        floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, 
b, span); },
+                       make_const(DataType::Int(32), 1), old_buffer->shape),
+                 fragment_size);
+    new_buffer_offset +=
+        floormod(pipeline_loop_->loop_var - pipeline_loop_->min, 
new_buffer->shape[0]) * offset;
+    return new_buffer_offset;
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    // Intrinsic calls should be handled explicitly here as they are opaque 
accesses to
+    // buffer.
+    static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync();

Review comment:
       Looks like we need to handle a few opaque intrinsics to make our
analysis possible. Given that the set of intrinsics could expand over time, do you
think it's possible to generalize this mechanism and make the logic clearer?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to