junrushao1994 commented on a change in pull request #10066: URL: https://github.com/apache/tvm/pull/10066#discussion_r797137808
########## File path: src/tir/transforms/inject_software_pipeline.cc ########## @@ -0,0 +1,785 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inject_software_pipeline.cc + * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers + */ +#include <tvm/target/target.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/transform.h> + +#include "../../support/utils.h" +#include "../schedule/utils.h" +#include "./ir_utils.h" + +namespace tvm { +namespace tir { + +namespace software_pipeline { + +/*! + * \brief Create a block and infer the access region with the given body. + * + * The result is an opaque block that doesn't contain any block iter vars. In case the body is a + * block realize without predicate, it is unnecessary to create a new block, the block of the block + * realize will be returned. + * + * \param body The body of the block. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \return The result block. 
+ */ +Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) { + if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) { + if (is_one(block_realize->predicate)) { + // no need to create a new block + return block_realize->block; + } + } + Block block = Block({}, {}, {}, "", body); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + auto* n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + return block; +} + +/*! Structure that represents the stage and order of the software pipeline component. */ +struct PipelineStageOrder { + int stage; + int order; + explicit PipelineStageOrder(int stage, int order) : stage(stage), order(order) {} +}; + +using PipelineInfo = std::unordered_map<Block, PipelineStageOrder, ObjectPtrHash, ObjectPtrEqual>; + +struct BufferAccessInfo { + int def; // the defining stage of the buffer + int use; // the last using stage of the buffer + explicit BufferAccessInfo(int def = -1, int use = -1) : def(def), use(use) {} +}; + +/*! + * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices + * of the remapped buffer to select the version corresponding to the pipeline stage. + */ +class PipelineBodyRewriter : public StmtExprMutator { + public: + /*! + * \brief Constructor of PipelineBodyRewriter. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \param buffer_remap The map from original buffer to the buffer with updated shape for + * multi-versioning in the software pipeline. + * \param pipeline_loop The original loop to be software pipelined. + * \param access_all_versions Whether all versions of the buffers in the software pipeline are + * accessed. This will be used to update block access region. In the prologue and epilogue + * of a two-stage software pipeline, only one version of these buffers is accessed. 
+ * \param fragment_info Information about tensor core fragment + */ + PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer, + const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop, + bool access_all_versions, + const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) + : buffer_data_to_buffer_(buffer_data_to_buffer), + buffer_remap_(buffer_remap), + pipeline_loop_(pipeline_loop), + access_all_versions_(access_all_versions), + fragment_info_(fragment_info) {} + + private: + BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { + auto it = buffer_remap_.find(buffer_region->buffer); + if (it != buffer_remap_.end()) { + Region new_region = buffer_region->region; + const Buffer& new_buffer = (*it).second; + // For pipeline buffers, relax the access region of the first dimension to full extent + // if access_all_versions == true + Range accessed_version = + access_all_versions_ + ? Range::FromMinExtent(0, new_buffer->shape[0]) + : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min), + new_buffer->shape[0]), + Integer(1)); + new_region.insert(new_region.begin(), accessed_version); + return BufferRegion(new_buffer, new_region); + } + return buffer_region; + } + + Stmt VisitStmt_(const BlockNode* op) final { + for (const Buffer& alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); + } + Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); + BlockNode* n = block.CopyOnWrite(); + n->reads.MutateByApply( + std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + n->writes.MutateByApply( + std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + for (const Buffer& alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(alloc_buffer->data); + } + return std::move(block); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore 
store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) { + return std::move(store); + } + const Buffer& new_buffer = (*it).second; + auto* n = store.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) { + return std::move(load); + } + const Buffer& new_buffer = (*it).second; + auto* n = load.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return std::move(load); + } + + int GetWmmaFragmentSize(const Buffer& buffer) { + auto it = fragment_info_.find(buffer->data.get()); + ICHECK(it != fragment_info_.end()); + const FragmentInfo& info = (*it).second; + String scope = buffer.scope(); + if (scope == "wmma.matrix_a") { + return info.m * info.k; + } else if (scope == "wmma.matrix_b") { + return info.n * info.k; + } else if (scope == "wmma.accumulator") { + return info.m * info.n; + } else { + ICHECK(0); + throw; + } + } + + PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, + const PrimExpr& old_index) { + PrimExpr new_buffer_offset = old_index; + + int fragment_size = GetWmmaFragmentSize(old_buffer); + PrimExpr offset = + floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), old_buffer->shape), + fragment_size); + new_buffer_offset += + floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; + return new_buffer_offset; + } + + 
PrimExpr VisitExpr_(const CallNode* op) final { + // Intrinsic calls should be handled explicitly here as they are opaque accesses to + // buffer. + static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); + static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); + static const auto& mma_sync = builtin::tvm_mma_sync(); + static const auto& access_ptr = builtin::tvm_access_ptr(); + Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array<PrimExpr> new_args = call->args; + const Buffer& new_buffer = (*it).second; + new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); + return Call(call->dtype, call->op, new_args, call->span); + } + } else if (call->op.same_as(mma_sync)) { + Array<PrimExpr> new_args = call->args; + for (int i = 0; i < 4; i++) { + const Var& buffer_var = Downcast<Var>(call->args[i * 2]); + const PrimExpr& index = call->args[i * 2 + 1]; + const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); + new_args.Set(i * 2 + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } else if (call->op.same_as(access_ptr)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[1])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array<PrimExpr> new_args = call->args; + const Buffer& new_buffer = (*it).second; + const PrimExpr& old_index = call->args[2]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), 
buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; + new_args.Set(2, new_index); + return Call(call->dtype, call->op, new_args, call->span); + } + } + return std::move(call); + } + + Map<Var, Buffer> buffer_data_to_buffer_; + Map<Buffer, Buffer> buffer_remap_; + For pipeline_loop_; + bool access_all_versions_; + const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_; +}; + +/*! + * \brief Rewriter for the software pipeline that rewrite a loop into a pipelined one. + */ +class PipelineRewriter : public StmtExprMutator { + public: + static Stmt Rewrite( + Map<Var, Buffer> buffer_data_to_buffer, + const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers, + const Array<Buffer> pipeline_allocs, const For& pipeline_loop, + const PipelineInfo& pipeline_info, + const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) { + PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, + pipeline_info, fragment_info); + return rewriter.BuildPipeline(); + } + + private: + PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, + const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers, + const Array<Buffer>& pipeline_allocs, const For& pipeline_loop, + const PipelineInfo& pipeline_info, + const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) + + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), + double_buffers_(double_buffers), + pipeline_allocs_(pipeline_allocs), + pipeline_loop_(pipeline_loop), + pipeline_info_(pipeline_info), + fragment_info_(fragment_info) {} + + Stmt BuildPipeline() { + // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions + // need to maintain for each buffer. 
+ RemapPipelineBuffers(pipeline_allocs_); + + ordered_stmts_.resize(pipeline_info_.size()); + for (const auto& pair : pipeline_info_) { + const Block& block = pair.first; + int order = pair.second.order; + ordered_stmts_.Set(order, block); + } + + // Step 2: Emit the pipeline prologue, body and epilogue. + Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); + Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, + pipeline_loop_->min + pipeline_loop_->extent, false); + Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true); + + SeqStmt stmt = SeqStmt({prologue, body, epilogue}); + + // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting. + Array<Buffer> alloc_buffers; + for (const auto& alloc : pipeline_allocs_) { + auto it = buffer_remap_.find(alloc); + if (it != buffer_remap_.end()) { + alloc_buffers.push_back((*it).second); + } else { + alloc_buffers.push_back(alloc); + } + buffer_data_to_buffer_.erase(alloc->data); + } + Block block = MakeBlock(stmt, buffer_data_to_buffer_); + auto* n = block.CopyOnWrite(); + n->alloc_buffers = std::move(alloc_buffers); + return BlockRealize({}, Bool(true), block); + } + + private: + /*! + * \brief Analyze accesses to the buffers in the software pipeline. + * + * This method check the 'define' and 'use' stage of the buffers in the software pipeline, which + * can be used to compute the number of versions needed to maintain after rewriting. 
+ */ + std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> + GetBufferAccessInfo() { + std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos; + for (const auto& pair : pipeline_info_) { + const Block& block = pair.first; + int stage = pair.second.stage; + max_stage_ = std::max(max_stage_, stage); + + for (const BufferRegion& write : block->writes) { + if (!infos.count(write->buffer)) { + infos.emplace(write->buffer, BufferAccessInfo{}); + } + auto& info = infos.at(write->buffer); + if (info.def == -1) { + info.def = stage; + } else { + info.def = std::min(info.def, stage); + } + } + + for (const BufferRegion& read : block->reads) { + if (!infos.count(read->buffer)) { + infos.emplace(read->buffer, BufferAccessInfo{}); + } + auto& info = infos.at(read->buffer); + info.use = std::max(info.use, stage); + } + } + return infos; + } + + /*! + * \brief Check whether two regions have intersections. + * \param region1 The first region. + * \param region2 The second region. + * \return Whether region1 and region2 have intersections. + */ + bool MayConflict(Region region1, Region region2) { + ICHECK(region1.size() == region2.size()); + for (size_t i = 0; i < region1.size(); i++) { + Range dim1 = region1[i]; + Range dim2 = region2[i]; + auto int_set1 = arith::IntSet::FromRange(dim1); + auto int_set2 = arith::IntSet::FromRange(dim2); + if (arith::Intersect({int_set1, int_set2}).IsNothing()) { + return false; + } + } + return true; + } + + /*! + * \brief Compute the number of versions need to maintain for buffer accessed in the software + * pipeline. + * + * This method applies liveness analysis to the target buffer to compute the number of versions + * need to maintain during the software pipeline. + * Annotation `attr::double_buffer_scope` is handled here which provides a way to override the + * result of the analysis. 
Additional double buffering in the software pipeline can be useful + * to eliminate synchonizations in GPU devices. Review comment: ```suggestion * to eliminate synchronizations in GPU devices. ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
