wrongtest-intellif commented on code in PR #17133:
URL: https://github.com/apache/tvm/pull/17133#discussion_r1705044956
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -527,96 +806,805 @@ Stmt TransformReductionBlock(const BlockRealizeNode*
realize, //
}
/*!
- * \brief Detect cross-thread reduction pattern and then transform
+ * \brief Inject the lowered allreduce block transformed from the input
reduction block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
*/
-class CrossThreadReductionTransformer : public StmtMutator {
- private:
- // Check if the input block needs cross-thread reduction.
- std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode*
realize) {
- // Step 0. If the block is the root block, just return.
- if (block_stack_.empty()) {
- return {};
- }
+Stmt InjectReductionBlock(const BlockRealizeNode* realize,
//
Review Comment:
Some empty trailing `//` comments are left over here.
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -527,96 +806,805 @@ Stmt TransformReductionBlock(const BlockRealizeNode*
realize, //
}
/*!
- * \brief Detect cross-thread reduction pattern and then transform
+ * \brief Inject the lowered allreduce block transformed from the input
reduction block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
*/
-class CrossThreadReductionTransformer : public StmtMutator {
- private:
- // Check if the input block needs cross-thread reduction.
- std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode*
realize) {
- // Step 0. If the block is the root block, just return.
- if (block_stack_.empty()) {
- return {};
- }
+Stmt InjectReductionBlock(const BlockRealizeNode* realize,
//
+ const Array<Buffer>& ct_buffers,
//
+ const Array<Buffer>& wb_buffers,
//
+ const Array<PrimExpr>& old_wb_indices,
//
+ const CommReducer& reducer,
//
+ const Array<PrimExpr>& combiner_lhs,
//
+ const std::vector<const ForNode*>& reduction_loops
//
+ ) {
+ int n_buffers = wb_buffers.size();
+ const BlockNode* block = realize->block.get();
- // Step 1. If the block is not a reduction block, cross-thread reduction
is not needed.
- if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
- GetRef<Block>(block_stack_.back()), &analyzer_)) {
- return {};
+ auto f_create_buffer_regions = [](Array<Buffer> buffers) {
+ Array<BufferRegion> regions;
+ regions.reserve(buffers.size());
+ for (const Buffer& buffer : buffers) {
+ regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
}
+ return regions;
+ };
- // Step 2. Collect all the vars that appear in the bindings of reduction
block iters.
- std::unordered_set<const VarNode*> reduction_vars;
- GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr,
&reduction_vars);
+ Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
+ Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
+ // In total, the block is transformed into at most 4 statements
+ // - Stmt 1: initialize the buffer for in-thread reduction
+ // - Stmt 2: do in-thread reduction
+ // - Stmt 3: do cross-thread reduction
+ // - Stmt 4: write cross-thread reduction result to the original buffer
+ Array<Stmt> stmts;
+ stmts.reserve(4);
- // Step 3. Collect the loops whose loop vars appear in the bindings of
reduction block iters.
- // We call these loops "reduction-related".
- // Step 4. See whether at least one reduction-related loop is bound to
thread axis in GPU - if
- // so, cross-thread reduction is needed. If none of the reduction-related
loops is bound to
- // thread axis, cross-thread reduction is not needed for the input block.
- bool need = false;
- std::vector<const ForNode*> reduction_loops;
- for (const ForNode* loop : loop_stack_) {
- if (reduction_vars.count(loop->loop_var.get())) {
- // Step 3. Collect the loop.
- reduction_loops.push_back(loop);
- // Step 4. See whether the loop is bound to some thread axis.
- if (loop->thread_binding.defined()) {
- need = true;
- }
+ // Stmt 3: do cross-thread reduction
+ {
+ // Step 3.1. Create the parameters to the intrinsic
+ Array<PrimExpr> parameters;
+ parameters.reserve(reduction_loops.size() + 4);
+ // 1-st argument: number of buffers
+ parameters.push_back(make_const(DataType::UInt(32), n_buffers));
+ // Next `n_buffers` arguments: sources
+ parameters.insert(parameters.end(), combiner_lhs.begin(),
combiner_lhs.end());
+ // Next argument: predicate
+ parameters.push_back(const_true());
+ // Next `n_buffers` arguments: destinations
+ for (int i = 0; i < n_buffers; ++i) {
+ parameters.push_back(BufferLoad(ct_buffers[i], {0}));
+ }
+ // Next arguments: all the reduction threads
+ for (const ForNode* reduction_loop : reduction_loops) {
+ if (reduction_loop->thread_binding.defined()) {
+ parameters.push_back(reduction_loop->loop_var);
}
}
- return need ? reduction_loops : std::vector<const ForNode*>{};
- }
-
- // Check if the input block needs thread broadcast rewrite.
- // One block needs broadcast rewrite when
- // 1. it consumes a buffer produced by cross-thread reduction under
- // the same kernel (i.e., same group of blockIdx),
- // 2. it writes to non-local memory,
- // 3. at least one of the reduction thread vars of the cross-thread reduction
- // is free to this block (i.e., not bound to the block).
- std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
- const BlockRealizeNode* realize) {
- Block block = realize->block;
-
- // If the block writes to local memory, no rewrite is needed.
- for (BufferRegion write_region : block->writes) {
- if (write_region->buffer.scope() == "local") {
- return {};
+ // Step 3.2. Create the block and the block-realize.
+ Array<IterVar> iter_vars = block->iter_vars;
+ Array<PrimExpr> bindings = realize->iter_values;
+ Array<BufferRegion> reads = block->writes;
+
+ // Blockized block should also be considered
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ // If has ChildBlocks, the reads should be analyzed from the child blocks
+ reads.clear();
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ reads.insert(reads.end(), child_block->block->writes.begin(),
+ child_block->block->writes.end());
}
}
- // Find out the reduction threads for the read-buffers which are produced
by
- // cross-thread reduction.
- std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
thread2range;
- for (BufferRegion read_region : block->reads) {
- auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
- if (buf_it == crt_buf2threads_.end()) {
- continue;
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
}
- for (auto [scope, range] : buf_it->second) {
- thread2range[scope] = range;
+ return false;
+ };
+ for (const BufferRegion& read : reads) {
+ for (const Range& range : read->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
}
}
- // Erase those threads which are not free to this block.
- for (const ForNode* loop : loop_stack_) {
- if (loop->thread_binding.defined()) {
- ThreadScope scope =
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
- thread2range.erase(scope);
+ Block cross_thread_block =
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(reads),
+ /*writes=*/ct_buffer_regions,
+ /*name_hint=*/block->name_hint + "_cross_thread",
+ /*body=*/
+ AttrStmt(/*node=*/reducer,
+ /*attr_key=*/tir::attr::reduce_scope,
+ /*value=*/make_zero(DataType::Handle()),
+ /*body=*/
+ Evaluate(Call(/*dtype=*/DataType::Handle(),
+
/*op=*/tir::builtin::tvm_thread_allreduce(),
+ /*args=*/std::move(parameters)))));
+ ObjectPtr<BlockNode> cross_thread_block_node =
Review Comment:
Could we initialize `kIsCrossThreadReductionApplied` directly in the `Block(...)` constructor call above?
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -155,6 +161,277 @@ Array<Buffer> MakeScratchpads(const Array<Buffer>&
reduction_buffers, bool is_cr
return new_buffers;
}
+/*!
+ * \brief Get init value from BufferStore Node
+ * \param block The block to be checked
+ * \return The init value
+*/
+class InitUpdateValueFinder : public StmtExprVisitor {
+ public:
+ /*!
+ * \brief Find the init value of the given block
+ * \param block The block to be checked
+ * \return The init value of the given block
+ */
+ static PrimExpr FindInit(const Block& block) {
+ InitUpdateValueFinder finder;
+ finder(block->body);
+ CHECK(finder.init_value_.defined()) << "The init value of the block is not
found";
+ return finder.init_value_;
+ }
+
+ /*!
+ * \brief Find the update value of the given block
+ * \param block The block to be checked
+ * \return The update value of the given block
+ */
+ static BufferStore FindUpdate(const Block& block) {
+ InitUpdateValueFinder finder;
+ finder(block->body);
+ CHECK(finder.update_value_.defined()) << "The update value of the block is
not found";
+ return finder.update_value_;
+ }
+
+ /*!
+ * \brief Check whether the input block has MMA operation
+ * \param realize The block to be checked
+ * \return A boolean indicating whether the input block has MMA operation.
+ */
+ static bool CheckHasMMA(const Block& block) {
+ InitUpdateValueFinder checker;
+ checker(block->body);
+ return checker.has_mma_;
+ }
+
+ private:
+ void VisitStmt_(const BufferStoreNode* node) final {
+ BufferStore store = GetRef<BufferStore>(node);
Review Comment:
Do we need to check that there is only a single buffer store and its related value?
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -155,6 +161,277 @@ Array<Buffer> MakeScratchpads(const Array<Buffer>&
reduction_buffers, bool is_cr
return new_buffers;
}
+/*!
+ * \brief Get init value from BufferStore Node
+ * \param block The block to be checked
+ * \return The init value
+*/
+class InitUpdateValueFinder : public StmtExprVisitor {
+ public:
+ /*!
+ * \brief Find the init value of the given block
+ * \param block The block to be checked
+ * \return The init value of the given block
+ */
+ static PrimExpr FindInit(const Block& block) {
+ InitUpdateValueFinder finder;
+ finder(block->body);
+ CHECK(finder.init_value_.defined()) << "The init value of the block is not
found";
+ return finder.init_value_;
+ }
+
+ /*!
+ * \brief Find the update value of the given block
+ * \param block The block to be checked
+ * \return The update value of the given block
+ */
+ static BufferStore FindUpdate(const Block& block) {
+ InitUpdateValueFinder finder;
+ finder(block->body);
+ CHECK(finder.update_value_.defined()) << "The update value of the block is
not found";
+ return finder.update_value_;
+ }
+
+ /*!
+ * \brief Check whether the input block has MMA operation
+ * \param realize The block to be checked
+ * \return A boolean indicating whether the input block has MMA operation.
+ */
+ static bool CheckHasMMA(const Block& block) {
+ InitUpdateValueFinder checker;
+ checker(block->body);
+ return checker.has_mma_;
+ }
+
+ private:
+ void VisitStmt_(const BufferStoreNode* node) final {
+ BufferStore store = GetRef<BufferStore>(node);
+ init_value_ = store->value;
+ update_value_ = store;
+ return StmtVisitor::VisitStmt_(node);
+ }
+
+ void VisitExpr_(const CallNode* op) {
+ // TODO: Should append more test case for wmma
+ if (op->op.same_as(tir::builtin::ptx_mma())) {
+ has_mma_ = Bool(true);
+ } else if (op->op.same_as(tir::builtin::mma_fill())) {
+ has_mma_ = Bool(true);
+ init_value_ = make_const(DataType::Float(16), 0);
+ }
+ return StmtExprVisitor::VisitExpr_(op);
+ }
+
+ Bool has_mma_{false};
+ PrimExpr init_value_{nullptr};
+ BufferStore update_value_{nullptr};
+};
+
+/*!
+ * \brief Check whether the input block has child blocks.
+ * \param realize The block to be checked
+ * \return A boolean indicating whether the input block has child blocks.
+ * \note A similar check has been implemented in
"src/tir/schedule/analysis.h", but that check is
+ * based on `tir.Schedule`. Here we have no schedule information, and thus we
must implement the
+ * check again.
+ */
+class HasChildBlocksChecker : public StmtVisitor {
+ public:
+ /*!
+ * \brief Check if the given block has child blocks.
+ * \param realize The block to be checked
+ * \return True if the block has child blocks, false otherwise.
+ */
+ static bool Check(const Block& block) {
+ HasChildBlocksChecker checker;
+ checker(block->body);
+ return checker.has_child_blocks_;
+ }
+
+ /*!
+ * \brief Get child blocks of the given block realize
+ * \param realize The block to be checked
+ * \return The child blocks of the given block realize
+ */
+ static Array<BlockRealize> GetChildBlockRealizes(const Block& block) {
+ HasChildBlocksChecker checker;
+ checker(block->body);
+ return checker.child_blocks_;
+ }
+
+ private:
+ void VisitStmt_(const BlockNode* block) final {
+ has_child_blocks_ = true;
+ return StmtVisitor::VisitStmt_(block);
+ }
+
+ void VisitStmt_(const BlockRealizeNode* block_realize) final {
+ child_blocks_.push_back(GetRef<BlockRealize>(block_realize));
+ return StmtVisitor::VisitStmt_(block_realize);
+ }
+
+ bool has_child_blocks_{false};
+ Array<BlockRealize> child_blocks_;
+};
+/*!
+* \brief Visit Block Stmt and Find blocks that write the specific buffer
+*/
+
+class BufferInitBlockFinder : public StmtVisitor {
+ public:
+ /*!
+ * \brief Find the blocks that write the specific buffer
+ * \param stmt The statement to be visited
+ * \param buffer The buffer to be found
+ * \return The blocks that write the specific buffer
+ */
+ static Array<Block> Find(const Stmt& stmt, const Buffer& buffer) {
+ const Block scope_block = Downcast<Block>(stmt);
+ BufferInitBlockFinder finder(scope_block, buffer);
+ finder(stmt);
+ return finder.blocks_;
+ }
+
+ private:
+ explicit BufferInitBlockFinder(const Block& scope_block, const Buffer&
buffer)
+ : scope_block_(scope_block), buffer_(buffer) {}
+
+ void VisitStmt_(const BlockNode* block) final {
+ // We assume that init block doesn't have reads.
+ if (block->reads.size() != 0) {
+ return StmtVisitor::VisitStmt_(block);
+ }
+ // Skip the block that has no writes.
+ if (block->writes.size() == 0) {
+ return StmtVisitor::VisitStmt_(block);
+ }
+ // Skip the block that is not the dominant block.
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ return StmtVisitor::VisitStmt_(block);
+ }
+ for (const BufferRegion& buffer_region : block->writes) {
+ if (buffer_region->buffer == buffer_) {
+ blocks_.push_back(GetRef<Block>(block));
+ }
+ }
+ return StmtVisitor::VisitStmt_(block);
+ }
+
+ const Block& scope_block_;
+ const Buffer& buffer_;
+ Array<Block> blocks_;
+};
+
+/* !
+ * \brief LoopVar Class to store the loop variables related information
+ * \brief loop_var The loop variable
+ * \brief min The minimum value of iteration
+ * \brief extent The extent of the iteration
+ * \brief kind The kind of the for loop
+ * \return The loop variables between stmt1 and stmt2
+ */
+class LoopVar {
Review Comment:
Could we just use a `ForNode` object?
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -354,6 +631,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode*
realize, //
new_block->body =
BufferReplacer::Run(wb_buffers, it_buffers.value(),
std::move(new_block->body));
new_block->init = NullOpt;
+ new_block->annotations.Set(kIsCrossThreadReductionApplied, Bool(true));
Review Comment:
We could use `tir::const_true()` here.
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -120,12 +131,7 @@ bool IsReductionBlock(const BlockRealize& realize, const
Map<Var, Range>& loop_r
if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) {
return false;
}
- // Cond 4. Dominant: the block is the only writer of its output, dominating
the reader of its
- // output buffers.
- if (!IsDominantBlock(scope_block, GetRef<Block>(block))) {
- return false;
- }
Review Comment:
Hi @LeiWang1999, could you explain why the dominant-block check here was removed?
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -155,6 +161,277 @@ Array<Buffer> MakeScratchpads(const Array<Buffer>&
reduction_buffers, bool is_cr
return new_buffers;
}
+/*!
+ * \brief Get init value from BufferStore Node
+ * \param block The block to be checked
+ * \return The init value
+*/
+class InitUpdateValueFinder : public StmtExprVisitor {
+ public:
+ /*!
+ * \brief Find the init value of the given block
+ * \param block The block to be checked
+ * \return The init value of the given block
+ */
+ static PrimExpr FindInit(const Block& block) {
Review Comment:
Judging by how they are used in context, maybe we could merge `FindInit()` and
`CheckHasMMA()` so that the block is visited only once.
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -527,96 +806,805 @@ Stmt TransformReductionBlock(const BlockRealizeNode*
realize, //
}
/*!
- * \brief Detect cross-thread reduction pattern and then transform
+ * \brief Inject the lowered allreduce block transformed from the input
reduction block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
*/
-class CrossThreadReductionTransformer : public StmtMutator {
- private:
- // Check if the input block needs cross-thread reduction.
- std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode*
realize) {
- // Step 0. If the block is the root block, just return.
- if (block_stack_.empty()) {
- return {};
- }
+Stmt InjectReductionBlock(const BlockRealizeNode* realize,
//
+ const Array<Buffer>& ct_buffers,
//
+ const Array<Buffer>& wb_buffers,
//
+ const Array<PrimExpr>& old_wb_indices,
//
+ const CommReducer& reducer,
//
+ const Array<PrimExpr>& combiner_lhs,
//
+ const std::vector<const ForNode*>& reduction_loops
//
+ ) {
+ int n_buffers = wb_buffers.size();
+ const BlockNode* block = realize->block.get();
- // Step 1. If the block is not a reduction block, cross-thread reduction
is not needed.
- if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
- GetRef<Block>(block_stack_.back()), &analyzer_)) {
- return {};
+ auto f_create_buffer_regions = [](Array<Buffer> buffers) {
+ Array<BufferRegion> regions;
+ regions.reserve(buffers.size());
+ for (const Buffer& buffer : buffers) {
+ regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
}
+ return regions;
+ };
- // Step 2. Collect all the vars that appear in the bindings of reduction
block iters.
- std::unordered_set<const VarNode*> reduction_vars;
- GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr,
&reduction_vars);
+ Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
+ Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
+ // In total, the block is transformed into at most 4 statements
+ // - Stmt 1: initialize the buffer for in-thread reduction
+ // - Stmt 2: do in-thread reduction
+ // - Stmt 3: do cross-thread reduction
+ // - Stmt 4: write cross-thread reduction result to the original buffer
+ Array<Stmt> stmts;
+ stmts.reserve(4);
- // Step 3. Collect the loops whose loop vars appear in the bindings of
reduction block iters.
- // We call these loops "reduction-related".
- // Step 4. See whether at least one reduction-related loop is bound to
thread axis in GPU - if
- // so, cross-thread reduction is needed. If none of the reduction-related
loops is bound to
- // thread axis, cross-thread reduction is not needed for the input block.
- bool need = false;
- std::vector<const ForNode*> reduction_loops;
- for (const ForNode* loop : loop_stack_) {
- if (reduction_vars.count(loop->loop_var.get())) {
- // Step 3. Collect the loop.
- reduction_loops.push_back(loop);
- // Step 4. See whether the loop is bound to some thread axis.
- if (loop->thread_binding.defined()) {
- need = true;
- }
+ // Stmt 3: do cross-thread reduction
+ {
+ // Step 3.1. Create the parameters to the intrinsic
+ Array<PrimExpr> parameters;
+ parameters.reserve(reduction_loops.size() + 4);
+ // 1-st argument: number of buffers
+ parameters.push_back(make_const(DataType::UInt(32), n_buffers));
+ // Next `n_buffers` arguments: sources
+ parameters.insert(parameters.end(), combiner_lhs.begin(),
combiner_lhs.end());
+ // Next argument: predicate
+ parameters.push_back(const_true());
+ // Next `n_buffers` arguments: destinations
+ for (int i = 0; i < n_buffers; ++i) {
+ parameters.push_back(BufferLoad(ct_buffers[i], {0}));
+ }
+ // Next arguments: all the reduction threads
+ for (const ForNode* reduction_loop : reduction_loops) {
+ if (reduction_loop->thread_binding.defined()) {
+ parameters.push_back(reduction_loop->loop_var);
}
}
- return need ? reduction_loops : std::vector<const ForNode*>{};
- }
-
- // Check if the input block needs thread broadcast rewrite.
- // One block needs broadcast rewrite when
- // 1. it consumes a buffer produced by cross-thread reduction under
- // the same kernel (i.e., same group of blockIdx),
- // 2. it writes to non-local memory,
- // 3. at least one of the reduction thread vars of the cross-thread reduction
- // is free to this block (i.e., not bound to the block).
- std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
- const BlockRealizeNode* realize) {
- Block block = realize->block;
-
- // If the block writes to local memory, no rewrite is needed.
- for (BufferRegion write_region : block->writes) {
- if (write_region->buffer.scope() == "local") {
- return {};
+ // Step 3.2. Create the block and the block-realize.
+ Array<IterVar> iter_vars = block->iter_vars;
+ Array<PrimExpr> bindings = realize->iter_values;
+ Array<BufferRegion> reads = block->writes;
+
+ // Blockized block should also be considered
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ // If has ChildBlocks, the reads should be analyzed from the child blocks
+ reads.clear();
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ reads.insert(reads.end(), child_block->block->writes.begin(),
+ child_block->block->writes.end());
}
}
- // Find out the reduction threads for the read-buffers which are produced
by
- // cross-thread reduction.
- std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
thread2range;
- for (BufferRegion read_region : block->reads) {
- auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
- if (buf_it == crt_buf2threads_.end()) {
- continue;
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
}
- for (auto [scope, range] : buf_it->second) {
- thread2range[scope] = range;
+ return false;
+ };
+ for (const BufferRegion& read : reads) {
+ for (const Range& range : read->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
}
}
- // Erase those threads which are not free to this block.
- for (const ForNode* loop : loop_stack_) {
- if (loop->thread_binding.defined()) {
- ThreadScope scope =
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
- thread2range.erase(scope);
+ Block cross_thread_block =
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(reads),
+ /*writes=*/ct_buffer_regions,
+ /*name_hint=*/block->name_hint + "_cross_thread",
+ /*body=*/
+ AttrStmt(/*node=*/reducer,
+ /*attr_key=*/tir::attr::reduce_scope,
+ /*value=*/make_zero(DataType::Handle()),
+ /*body=*/
+ Evaluate(Call(/*dtype=*/DataType::Handle(),
+
/*op=*/tir::builtin::tvm_thread_allreduce(),
+ /*args=*/std::move(parameters)))));
+ ObjectPtr<BlockNode> cross_thread_block_node =
+ make_object<BlockNode>(*cross_thread_block.operator->());
+ cross_thread_block_node->annotations.Set(kIsCrossThreadReductionApplied,
Bool(true));
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/const_true(),
+ /*block=*/cross_thread_block));
+ }
+ // Stmt 4: write cross-thread reduction result to the original buffer
+ {
+ ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+ int n_iter = static_cast<int>(block->iter_vars.size());
+ Array<IterVar> iter_vars;
+ Array<PrimExpr> bindings;
+ Map<Var, Var> var_map;
+ Array<Range> write_region = block->writes[0]->region;
+ iter_vars.reserve(n_iter);
+ bindings.reserve(n_iter);
+ for (int i = 0; i < n_iter; ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ const PrimExpr& binding = realize->iter_values[i];
+ if (iter_var->iter_type != kCommReduce) {
+ IterVar new_iter_var{nullptr};
+ {
+ ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
+ ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
+ n->var = Var(v);
+ new_iter_var = IterVar(n);
+ }
+ iter_vars.push_back(new_iter_var);
+ bindings.push_back(binding);
+ var_map.Set(iter_var->var, new_iter_var->var);
}
}
- std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
- for (auto [scope, range] : thread2range) {
- unbound_thread2range_list.emplace_back(scope, range);
- }
- return unbound_thread2range_list;
- }
- /*!
- * \brief Given that the input block needs cross-thread reduction, check if
cross-thread reduction
- * can be applied to the block (i.e., the block satisfies all necessary
conditions of cross-thread
- * reduction)
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ write_region = child_block->block->writes[0]->region;
+ }
+ }
+ Array<Stmt> wb_updates;
+ Array<BufferRegion> wb_regions;
+ wb_updates.reserve(n_buffers);
+ wb_regions.reserve(n_buffers);
+ int n_dim = static_cast<int>(old_wb_indices.size());
+ Array<Range> region = Substitute(write_region, var_map);
+ Array<PrimExpr> wb_indices;
+ wb_indices.reserve(n_dim);
+ for (int d = 0; d < n_dim; ++d) {
+ wb_indices.push_back(Substitute(old_wb_indices[d], var_map));
+ }
+ for (int i = 0; i < n_buffers; ++i) {
+ wb_updates.push_back(
+ BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}),
wb_indices));
+ wb_regions.push_back(BufferRegion(wb_buffers[i], region));
+ }
+
+ // Construct the predicate of the write-back block. It is the conjunction
of
+ // - each predicate clause of the original block which contains spatial
loop var, and
+ // - `t == 0` for each reduction thread dim when the write-back buffer is
not local.
+ PrimExpr wb_predicate = const_true();
+ std::unordered_set<const VarNode*> reduction_loop_vars;
+ reduction_loop_vars.reserve(reduction_loops.size());
+ for (const ForNode* reduction_loop : reduction_loops) {
+ reduction_loop_vars.insert(reduction_loop->loop_var.get());
+ }
+ PostOrderVisit(realize->predicate, [&wb_predicate,
&reduction_loop_vars](const ObjectRef& obj) {
+ if (const auto* and_node = obj.as<AndNode>()) {
+ Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
+ for (PrimExpr sub_expr : sub_exprs) {
+ if (sub_expr->IsInstance<AndNode>()) {
+ continue;
+ }
+ bool is_reduction = [sub_expr, &reduction_loop_vars]() {
+ Array<Var> vars = UndefinedVars(sub_expr);
+ for (Var var : vars) {
+ if (reduction_loop_vars.find(var.get()) !=
reduction_loop_vars.end()) {
+ return true;
+ }
+ }
+ return false;
+ }();
+ if (!is_reduction) {
+ wb_predicate = wb_predicate && sub_expr;
+ }
+ }
+ return true;
+ }
+ return false;
+ });
+ if (wb_buffers[0].scope() != "local") {
+ for (const ForNode* loop : reduction_loops) {
+ if (loop->thread_binding.defined()) {
+ wb_predicate = wb_predicate && (loop->loop_var ==
IntImm(loop->loop_var->dtype, 0));
+ }
+ }
+ }
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ for (const BufferRegion& write : wb_regions) {
+ for (const Range& range : write->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
+ }
+ }
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/wb_predicate,
+ /*block=*/
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(ct_buffer_regions),
+ /*writes=*/std::move(wb_regions),
+ /*name_hint=*/block->name_hint + "_write_back",
+ /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
+ }
+ // Final step: Wrap all the above four statements with the reduction loops
bound to threadIdx
+ Stmt new_stmt = Stmt();
+ for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend();
++rit) {
+ const ForNode* loop = *rit;
+ if (loop->thread_binding.defined()) {
+ // Colelct Loop vars between the reduction lops
+ std::vector<LoopVar> chain_loop_vars =
+ LoopVarCollector::Collect(loop->body, GetRef<Block>(block));
+ std::vector<LoopVar> used_chain_loop_vars_array;
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ chain_loop_vars.clear();
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ for (BlockRealize child_block : child_blocks) {
+ std::vector<LoopVar> child_loop_vars =
+ LoopVarCollector::Collect(loop->body, child_block->block);
+ chain_loop_vars.insert(chain_loop_vars.end(),
child_loop_vars.begin(),
+ child_loop_vars.end());
+ }
+ }
+
+ // Remove Unused Loop from the chain loops, otherwise may generate
duplicated for loops
+ for (auto it = chain_loop_vars.begin(); it != chain_loop_vars.end();
++it) {
+ Var target_var = (*it).loop_var;
+ auto f_find = [&target_var](const VarNode* var) -> bool {
+ if (target_var.get() == var) {
+ return true;
+ }
+ return false;
+ };
+ for (const Stmt& stmt : stmts) {
+ if (UsesVar(stmt, f_find)) {
+ used_chain_loop_vars_array.push_back(*it);
+ break;
+ }
+ }
+ }
+ chain_loop_vars = used_chain_loop_vars_array;
+
+ ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+ if (chain_loop_vars.size() == 0) {
+ stmts.insert(stmts.begin(), n->body);
+ new_stmt = SeqStmt::Flatten(std::move(stmts));
+ n->body = std::move(new_stmt);
+ new_stmt = For(n);
+ break;
+ } else {
+ new_stmt = SeqStmt::Flatten(std::move(stmts));
+ For new_for = For(chain_loop_vars.back().loop_var,
chain_loop_vars.back().min,
+ chain_loop_vars.back().extent,
chain_loop_vars.back().kind, new_stmt);
+
+ ObjectPtr<ForNode> current_loop = make_object<ForNode>(*new_for.get());
+ for (int i = chain_loop_vars.size() - 2; i >= 0; i--) {
+ new_for = For(chain_loop_vars[i].loop_var, chain_loop_vars[i].min,
+ chain_loop_vars[i].extent, chain_loop_vars[i].kind,
new_for);
+ }
+ new_stmt = SeqStmt::Flatten(std::move((SeqStmt({std::move(n->body),
std::move(new_for)}))));
+ n->body = std::move(new_stmt);
+ new_stmt = For(n);
+ break;
+ }
+ }
+ }
+ return new_stmt;
+}
+
+/*!
+ * \brief Inject the lowered warp evaluate allreduce block transformed from
the input reduction
+ * block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
+ * \param used_block The block used in the warp evaluate (init block)
+ */
+Stmt InjectWarpEvaluateReductionBlock(const BlockRealizeNode* realize,
//
+ const Array<Buffer>& ct_buffers,
//
+ const Array<Buffer>& wb_buffers,
//
+ const Array<PrimExpr>& old_wb_indices,
//
+ const CommReducer& reducer,
//
+ const Array<PrimExpr>&
combiner_evaluate, //
+ const std::vector<const ForNode*>&
reduction_loops //
+ ) {
+ int n_buffers = wb_buffers.size();
+ const BlockNode* block = realize->block.get();
+ Buffer write_buffer = (*block->writes.begin())->buffer;
+ PrimExpr warp_size = write_buffer->shape[write_buffer->shape.size() - 2];
+ PrimExpr local_size = write_buffer->shape[write_buffer->shape.size() - 1];
+ // Create IterVars
+ Var v_lane_id = Var("v_lane_id");
+ Var v_local_id = Var("v_local_id");
+ // Create Bindings
+ Var ax_lane_id = Var("ax_lane_id");
+ Var ax_local_id = Var("ax_local_id");
+ // Create IterVar
+ IterVar lane_id = IterVar(Range(0, warp_size), v_lane_id, kDataPar,
"threadIdx.x");
+ IterVar local_id = IterVar(Range(0, local_size), v_local_id, kDataPar);
+
+ auto f_create_buffer_regions = [](Array<Buffer> buffers) {
+ Array<BufferRegion> regions;
+ regions.reserve(buffers.size());
+ for (const Buffer& buffer : buffers) {
+ regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
+ }
+ return regions;
+ };
+
+ Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
+ Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
+ // In total, the block is transformed into at most 4 statements
+ // - Stmt 1: initialize the buffer for in-thread reduction
+ // - Stmt 2: do in-thread reduction
+ // - Stmt 3: do cross-thread reduction
+ // - Stmt 4: write cross-thread reduction result to the original buffer
+ Array<Stmt> stmts;
+ stmts.reserve(4);
+
+ // Stmt 3: do cross-thread reduction
+ {
+ // Step 3.1. Create the parameters to the intrinsic
+ Array<PrimExpr> parameters;
+
+ // Step 3.2. Create the block and the block-realize.
+ Array<IterVar> iter_vars = block->iter_vars;
+ Array<PrimExpr> bindings = realize->iter_values;
+ Array<BufferRegion> reads = block->writes;
+
+ // Blockized block should also be considered
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ // If has ChildBlocks, the reads should be analyzed from the child blocks
+ reads.clear();
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ reads.insert(reads.end(), child_block->block->writes.begin(),
+ child_block->block->writes.end());
+ }
+ }
+
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ for (const BufferRegion& read : reads) {
+ for (const Range& range : read->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
+ }
+ }
+
+ // Create the vars for the warp
+ // Collect the Warp Information from the Buffer
+ // The threadIdx and Warp Information is stored in the last two dimensions
of the write buffer
+
+ iter_vars_used.push_back(lane_id);
+ iter_vars_used.push_back(local_id);
+
+ bindings_used.push_back(ax_lane_id);
+ bindings_used.push_back(ax_local_id);
+
+ Array<PrimExpr> combiner_warp;
+ for (size_t i = 0; i < combiner_evaluate.size(); i++) {
+ Array<PrimExpr> new_indices; // indices that maintain the new indices
with warp information
+ for (size_t j = 0; j <
combiner_evaluate[i].as<BufferLoadNode>()->indices.size() - 2; j++) {
+
new_indices.push_back(combiner_evaluate[i].as<BufferLoadNode>()->indices[j]);
+ }
+ new_indices.push_back(v_lane_id);
+ new_indices.push_back(v_local_id);
+ combiner_warp.push_back(
+ BufferLoad(combiner_evaluate[i].as<BufferLoadNode>()->buffer,
new_indices));
+ }
+ // Create the parameters to the intrinsic
+ parameters.reserve(reduction_loops.size() + 4);
+ // 1-st argument: number of buffers
+ parameters.push_back(make_const(DataType::UInt(32), n_buffers));
+ // Next `n_buffers` arguments: sources
+ parameters.insert(parameters.end(), combiner_warp.begin(),
combiner_warp.end());
+ // Next argument: predicate
+ parameters.push_back(const_true());
+ // Next `n_buffers` arguments: destinations
+ for (int i = 0; i < n_buffers; ++i) {
+ parameters.push_back(BufferLoad(ct_buffers[i], {0}));
+ }
+ // Next arguments: all the reduction threads
+ for (const ForNode* reduction_loop : reduction_loops) {
+ if (reduction_loop->thread_binding.defined()) {
+ parameters.push_back(reduction_loop->loop_var);
+ }
+ }
+ // update param indices
+ Block cross_thread_block =
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(reads),
+ /*writes=*/ct_buffer_regions,
+ /*name_hint=*/block->name_hint + "_cross_thread",
+ /*body=*/
+ AttrStmt(/*node=*/reducer,
+ /*attr_key=*/tir::attr::reduce_scope,
+ /*value=*/make_zero(DataType::Handle()),
+ /*body=*/
+ Evaluate(Call(/*dtype=*/DataType::Handle(),
+
/*op=*/tir::builtin::tvm_thread_allreduce(),
+ /*args=*/std::move(parameters)))));
+
+ ObjectPtr<BlockNode> cross_thread_block_node =
+ make_object<BlockNode>(*cross_thread_block.operator->());
+ cross_thread_block_node->annotations.Set(kIsCrossThreadReductionApplied,
Bool(true));
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/const_true(),
+ /*block=*/cross_thread_block));
+ }
+ // Stmt 4: write cross-thread reduction result to the original buffer
+ {
+ ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+ int n_iter = static_cast<int>(block->iter_vars.size());
+ Array<IterVar> iter_vars;
+ Array<PrimExpr> bindings;
+ Array<Range> write_region = block->writes[0]->region;
+ iter_vars.reserve(n_iter);
+ bindings.reserve(n_iter);
+ for (int i = 0; i < n_iter; ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ const PrimExpr& binding = realize->iter_values[i];
+ if (iter_var->iter_type != kCommReduce) {
+ iter_vars.push_back(iter_var);
+ bindings.push_back(binding);
+ }
+ }
+
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ write_region = child_block->block->writes[0]->region;
+ }
+ }
+ Array<Stmt> wb_updates;
+ Array<BufferRegion> wb_regions;
+ wb_updates.reserve(n_buffers);
+ wb_regions.reserve(n_buffers);
+
+ Array<Range> region = write_region;
+ Array<PrimExpr> wb_indices;
+ for (size_t i = 0; i < block->writes[0]->region.size() - 2; i++) {
+ Range range = block->writes[0]->region[i];
+ wb_indices.push_back(range->min);
+ }
+ // Append the warp and local id to the indices
+ wb_indices.push_back(v_lane_id);
+ wb_indices.push_back(v_local_id);
+ for (int i = 0; i < n_buffers; ++i) {
+ wb_updates.push_back(
+ BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}),
wb_indices));
+ wb_regions.push_back(BufferRegion(wb_buffers[i], region));
+ }
+
+ // Construct the predicate of the write-back block. It is the conjunction
of
+ // - each predicate clause of the original block which contains spatial
loop var, and
+ // - `t == 0` for each reduction thread dim when the write-back buffer is
not local.
+ PrimExpr wb_predicate = const_true();
+ std::unordered_set<const VarNode*> reduction_loop_vars;
+ reduction_loop_vars.reserve(reduction_loops.size());
+ for (const ForNode* reduction_loop : reduction_loops) {
+ reduction_loop_vars.insert(reduction_loop->loop_var.get());
+ }
+ PostOrderVisit(realize->predicate, [&wb_predicate,
&reduction_loop_vars](const ObjectRef& obj) {
+ if (const auto* and_node = obj.as<AndNode>()) {
+ Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
+ for (PrimExpr sub_expr : sub_exprs) {
+ if (sub_expr->IsInstance<AndNode>()) {
+ continue;
+ }
+ bool is_reduction = [sub_expr, &reduction_loop_vars]() {
+ Array<Var> vars = UndefinedVars(sub_expr);
+ for (Var var : vars) {
+ if (reduction_loop_vars.find(var.get()) !=
reduction_loop_vars.end()) {
+ return true;
+ }
+ }
+ return false;
+ }();
+ if (!is_reduction) {
+ wb_predicate = wb_predicate && sub_expr;
+ }
+ }
+ return true;
+ }
+ return false;
+ });
+ if (wb_buffers[0].scope() != "local") {
+ for (const ForNode* loop : reduction_loops) {
+ if (loop->thread_binding.defined()) {
+ wb_predicate = wb_predicate && (loop->loop_var ==
IntImm(loop->loop_var->dtype, 0));
+ }
+ }
+ }
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ for (const BufferRegion& write : wb_regions) {
+ for (const Range& range : write->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
+ }
+ }
+
+ iter_vars_used.push_back(IterVar(Range(0, warp_size), v_lane_id, kDataPar,
"threadIdx.x"));
+ iter_vars_used.push_back(IterVar(Range(0, local_size), v_local_id,
kDataPar));
+
+ bindings_used.push_back(ax_lane_id);
+ bindings_used.push_back(ax_local_id);
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/wb_predicate,
+ /*block=*/
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(ct_buffer_regions),
+ /*writes=*/std::move(wb_regions),
+ /*name_hint=*/block->name_hint + "_write_back",
+ /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
+ }
+ // Final step: Wrap all the above four statements with the reduction loops
bound to threadIdx
+ Stmt new_stmt = Stmt();
+ for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend();
++rit) {
+ const ForNode* loop = *rit;
+ if (loop->thread_binding.defined()) {
+ // Colelct Loop vars between the reduction lops
Review Comment:
typo: "Colelct" should be "Collect" and "lops" should be "loops"
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -155,6 +161,277 @@ Array<Buffer> MakeScratchpads(const Array<Buffer>&
reduction_buffers, bool is_cr
return new_buffers;
}
+/*!
+ * \brief Get init value from BufferStore Node
+ * \param block The block to be checked
+ * \return The init value
+*/
+class InitUpdateValueFinder : public StmtExprVisitor {
+ public:
+ /*!
+ * \brief Find the init value of the given block
+ * \param block The block to be checked
+ * \return The init value of the given block
+ */
+ static PrimExpr FindInit(const Block& block) {
+ InitUpdateValueFinder finder;
+ finder(block->body);
+ CHECK(finder.init_value_.defined()) << "The init value of the block is not
found";
+ return finder.init_value_;
+ }
+
+ /*!
+ * \brief Find the update value of the given block
+ * \param block The block to be checked
+ * \return The update value of the given block
+ */
+ static BufferStore FindUpdate(const Block& block) {
+ InitUpdateValueFinder finder;
+ finder(block->body);
+ CHECK(finder.update_value_.defined()) << "The update value of the block is
not found";
+ return finder.update_value_;
+ }
+
+ /*!
+ * \brief Check whether the input block has MMA operation
+ * \param realize The block to be checked
+ * \return A boolean indicating whether the input block has MMA operation.
+ */
+ static bool CheckHasMMA(const Block& block) {
+ InitUpdateValueFinder checker;
+ checker(block->body);
+ return checker.has_mma_;
+ }
+
+ private:
+ void VisitStmt_(const BufferStoreNode* node) final {
+ BufferStore store = GetRef<BufferStore>(node);
+ init_value_ = store->value;
+ update_value_ = store;
+ return StmtVisitor::VisitStmt_(node);
+ }
+
+ void VisitExpr_(const CallNode* op) {
+ // TODO: Should append more test case for wmma
+ if (op->op.same_as(tir::builtin::ptx_mma())) {
+ has_mma_ = Bool(true);
+ } else if (op->op.same_as(tir::builtin::mma_fill())) {
+ has_mma_ = Bool(true);
+ init_value_ = make_const(DataType::Float(16), 0);
+ }
+ return StmtExprVisitor::VisitExpr_(op);
+ }
+
+ Bool has_mma_{false};
Review Comment:
The boxed `Bool` object seems to be unnecessary here; a plain `bool` member would do
##########
src/tir/transforms/lower_cross_thread_reduction.cc:
##########
@@ -527,96 +806,805 @@ Stmt TransformReductionBlock(const BlockRealizeNode*
realize, //
}
/*!
- * \brief Detect cross-thread reduction pattern and then transform
+ * \brief Inject the lowered allreduce block transformed from the input
reduction block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
*/
-class CrossThreadReductionTransformer : public StmtMutator {
- private:
- // Check if the input block needs cross-thread reduction.
- std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode*
realize) {
- // Step 0. If the block is the root block, just return.
- if (block_stack_.empty()) {
- return {};
- }
+Stmt InjectReductionBlock(const BlockRealizeNode* realize,
//
+ const Array<Buffer>& ct_buffers,
//
+ const Array<Buffer>& wb_buffers,
//
+ const Array<PrimExpr>& old_wb_indices,
//
+ const CommReducer& reducer,
//
+ const Array<PrimExpr>& combiner_lhs,
//
+ const std::vector<const ForNode*>& reduction_loops
//
+ ) {
+ int n_buffers = wb_buffers.size();
+ const BlockNode* block = realize->block.get();
- // Step 1. If the block is not a reduction block, cross-thread reduction
is not needed.
- if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
- GetRef<Block>(block_stack_.back()), &analyzer_)) {
- return {};
+ auto f_create_buffer_regions = [](Array<Buffer> buffers) {
+ Array<BufferRegion> regions;
+ regions.reserve(buffers.size());
+ for (const Buffer& buffer : buffers) {
+ regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
}
+ return regions;
+ };
- // Step 2. Collect all the vars that appear in the bindings of reduction
block iters.
- std::unordered_set<const VarNode*> reduction_vars;
- GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr,
&reduction_vars);
+ Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
+ Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
+ // In total, the block is transformed into at most 4 statements
+ // - Stmt 1: initialize the buffer for in-thread reduction
+ // - Stmt 2: do in-thread reduction
+ // - Stmt 3: do cross-thread reduction
+ // - Stmt 4: write cross-thread reduction result to the original buffer
+ Array<Stmt> stmts;
+ stmts.reserve(4);
- // Step 3. Collect the loops whose loop vars appear in the bindings of
reduction block iters.
- // We call these loops "reduction-related".
- // Step 4. See whether at least one reduction-related loop is bound to
thread axis in GPU - if
- // so, cross-thread reduction is needed. If none of the reduction-related
loops is bound to
- // thread axis, cross-thread reduction is not needed for the input block.
- bool need = false;
- std::vector<const ForNode*> reduction_loops;
- for (const ForNode* loop : loop_stack_) {
- if (reduction_vars.count(loop->loop_var.get())) {
- // Step 3. Collect the loop.
- reduction_loops.push_back(loop);
- // Step 4. See whether the loop is bound to some thread axis.
- if (loop->thread_binding.defined()) {
- need = true;
- }
+ // Stmt 3: do cross-thread reduction
+ {
+ // Step 3.1. Create the parameters to the intrinsic
+ Array<PrimExpr> parameters;
+ parameters.reserve(reduction_loops.size() + 4);
+ // 1-st argument: number of buffers
+ parameters.push_back(make_const(DataType::UInt(32), n_buffers));
+ // Next `n_buffers` arguments: sources
+ parameters.insert(parameters.end(), combiner_lhs.begin(),
combiner_lhs.end());
+ // Next argument: predicate
+ parameters.push_back(const_true());
+ // Next `n_buffers` arguments: destinations
+ for (int i = 0; i < n_buffers; ++i) {
+ parameters.push_back(BufferLoad(ct_buffers[i], {0}));
+ }
+ // Next arguments: all the reduction threads
+ for (const ForNode* reduction_loop : reduction_loops) {
+ if (reduction_loop->thread_binding.defined()) {
+ parameters.push_back(reduction_loop->loop_var);
}
}
- return need ? reduction_loops : std::vector<const ForNode*>{};
- }
-
- // Check if the input block needs thread broadcast rewrite.
- // One block needs broadcast rewrite when
- // 1. it consumes a buffer produced by cross-thread reduction under
- // the same kernel (i.e., same group of blockIdx),
- // 2. it writes to non-local memory,
- // 3. at least one of the reduction thread vars of the cross-thread reduction
- // is free to this block (i.e., not bound to the block).
- std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
- const BlockRealizeNode* realize) {
- Block block = realize->block;
-
- // If the block writes to local memory, no rewrite is needed.
- for (BufferRegion write_region : block->writes) {
- if (write_region->buffer.scope() == "local") {
- return {};
+ // Step 3.2. Create the block and the block-realize.
+ Array<IterVar> iter_vars = block->iter_vars;
+ Array<PrimExpr> bindings = realize->iter_values;
+ Array<BufferRegion> reads = block->writes;
+
+ // Blockized block should also be considered
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ // If has ChildBlocks, the reads should be analyzed from the child blocks
+ reads.clear();
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ reads.insert(reads.end(), child_block->block->writes.begin(),
+ child_block->block->writes.end());
}
}
- // Find out the reduction threads for the read-buffers which are produced
by
- // cross-thread reduction.
- std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
thread2range;
- for (BufferRegion read_region : block->reads) {
- auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
- if (buf_it == crt_buf2threads_.end()) {
- continue;
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
}
- for (auto [scope, range] : buf_it->second) {
- thread2range[scope] = range;
+ return false;
+ };
+ for (const BufferRegion& read : reads) {
+ for (const Range& range : read->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
}
}
- // Erase those threads which are not free to this block.
- for (const ForNode* loop : loop_stack_) {
- if (loop->thread_binding.defined()) {
- ThreadScope scope =
ThreadScope::Create(loop->thread_binding.value()->thread_tag);
- thread2range.erase(scope);
+ Block cross_thread_block =
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(reads),
+ /*writes=*/ct_buffer_regions,
+ /*name_hint=*/block->name_hint + "_cross_thread",
+ /*body=*/
+ AttrStmt(/*node=*/reducer,
+ /*attr_key=*/tir::attr::reduce_scope,
+ /*value=*/make_zero(DataType::Handle()),
+ /*body=*/
+ Evaluate(Call(/*dtype=*/DataType::Handle(),
+
/*op=*/tir::builtin::tvm_thread_allreduce(),
+ /*args=*/std::move(parameters)))));
+ ObjectPtr<BlockNode> cross_thread_block_node =
+ make_object<BlockNode>(*cross_thread_block.operator->());
+ cross_thread_block_node->annotations.Set(kIsCrossThreadReductionApplied,
Bool(true));
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/const_true(),
+ /*block=*/cross_thread_block));
+ }
+ // Stmt 4: write cross-thread reduction result to the original buffer
+ {
+ ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+ int n_iter = static_cast<int>(block->iter_vars.size());
+ Array<IterVar> iter_vars;
+ Array<PrimExpr> bindings;
+ Map<Var, Var> var_map;
+ Array<Range> write_region = block->writes[0]->region;
+ iter_vars.reserve(n_iter);
+ bindings.reserve(n_iter);
+ for (int i = 0; i < n_iter; ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ const PrimExpr& binding = realize->iter_values[i];
+ if (iter_var->iter_type != kCommReduce) {
+ IterVar new_iter_var{nullptr};
+ {
+ ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
+ ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
+ n->var = Var(v);
+ new_iter_var = IterVar(n);
+ }
+ iter_vars.push_back(new_iter_var);
+ bindings.push_back(binding);
+ var_map.Set(iter_var->var, new_iter_var->var);
}
}
- std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
- for (auto [scope, range] : thread2range) {
- unbound_thread2range_list.emplace_back(scope, range);
- }
- return unbound_thread2range_list;
- }
- /*!
- * \brief Given that the input block needs cross-thread reduction, check if
cross-thread reduction
- * can be applied to the block (i.e., the block satisfies all necessary
conditions of cross-thread
- * reduction)
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ write_region = child_block->block->writes[0]->region;
+ }
+ }
+ Array<Stmt> wb_updates;
+ Array<BufferRegion> wb_regions;
+ wb_updates.reserve(n_buffers);
+ wb_regions.reserve(n_buffers);
+ int n_dim = static_cast<int>(old_wb_indices.size());
+ Array<Range> region = Substitute(write_region, var_map);
+ Array<PrimExpr> wb_indices;
+ wb_indices.reserve(n_dim);
+ for (int d = 0; d < n_dim; ++d) {
+ wb_indices.push_back(Substitute(old_wb_indices[d], var_map));
+ }
+ for (int i = 0; i < n_buffers; ++i) {
+ wb_updates.push_back(
+ BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}),
wb_indices));
+ wb_regions.push_back(BufferRegion(wb_buffers[i], region));
+ }
+
+ // Construct the predicate of the write-back block. It is the conjunction
of
+ // - each predicate clause of the original block which contains spatial
loop var, and
+ // - `t == 0` for each reduction thread dim when the write-back buffer is
not local.
+ PrimExpr wb_predicate = const_true();
+ std::unordered_set<const VarNode*> reduction_loop_vars;
+ reduction_loop_vars.reserve(reduction_loops.size());
+ for (const ForNode* reduction_loop : reduction_loops) {
+ reduction_loop_vars.insert(reduction_loop->loop_var.get());
+ }
+ PostOrderVisit(realize->predicate, [&wb_predicate,
&reduction_loop_vars](const ObjectRef& obj) {
+ if (const auto* and_node = obj.as<AndNode>()) {
+ Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
+ for (PrimExpr sub_expr : sub_exprs) {
+ if (sub_expr->IsInstance<AndNode>()) {
+ continue;
+ }
+ bool is_reduction = [sub_expr, &reduction_loop_vars]() {
+ Array<Var> vars = UndefinedVars(sub_expr);
+ for (Var var : vars) {
+ if (reduction_loop_vars.find(var.get()) !=
reduction_loop_vars.end()) {
+ return true;
+ }
+ }
+ return false;
+ }();
+ if (!is_reduction) {
+ wb_predicate = wb_predicate && sub_expr;
+ }
+ }
+ return true;
+ }
+ return false;
+ });
+ if (wb_buffers[0].scope() != "local") {
+ for (const ForNode* loop : reduction_loops) {
+ if (loop->thread_binding.defined()) {
+ wb_predicate = wb_predicate && (loop->loop_var ==
IntImm(loop->loop_var->dtype, 0));
+ }
+ }
+ }
+ // Remove unused iter vars which introduced by blockize
+ // otherwise may generate duplicated for loops
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ for (const BufferRegion& write : wb_regions) {
+ for (const Range& range : write->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
+ }
+ }
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/wb_predicate,
+ /*block=*/
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(ct_buffer_regions),
+ /*writes=*/std::move(wb_regions),
+ /*name_hint=*/block->name_hint + "_write_back",
+ /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
+ }
+ // Final step: Wrap all the above four statements with the reduction loops
bound to threadIdx
+ Stmt new_stmt = Stmt();
+ for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend();
++rit) {
+ const ForNode* loop = *rit;
+ if (loop->thread_binding.defined()) {
+ // Colelct Loop vars between the reduction lops
+ std::vector<LoopVar> chain_loop_vars =
+ LoopVarCollector::Collect(loop->body, GetRef<Block>(block));
+ std::vector<LoopVar> used_chain_loop_vars_array;
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ chain_loop_vars.clear();
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ for (BlockRealize child_block : child_blocks) {
+ std::vector<LoopVar> child_loop_vars =
+ LoopVarCollector::Collect(loop->body, child_block->block);
+ chain_loop_vars.insert(chain_loop_vars.end(),
child_loop_vars.begin(),
+ child_loop_vars.end());
+ }
+ }
+
+ // Remove Unused Loop from the chain loops, otherwise may generate
duplicated for loops
+ for (auto it = chain_loop_vars.begin(); it != chain_loop_vars.end();
++it) {
+ Var target_var = (*it).loop_var;
+ auto f_find = [&target_var](const VarNode* var) -> bool {
+ if (target_var.get() == var) {
+ return true;
+ }
+ return false;
+ };
+ for (const Stmt& stmt : stmts) {
+ if (UsesVar(stmt, f_find)) {
+ used_chain_loop_vars_array.push_back(*it);
+ break;
+ }
+ }
+ }
+ chain_loop_vars = used_chain_loop_vars_array;
+
+ ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+ if (chain_loop_vars.size() == 0) {
+ stmts.insert(stmts.begin(), n->body);
+ new_stmt = SeqStmt::Flatten(std::move(stmts));
+ n->body = std::move(new_stmt);
+ new_stmt = For(n);
+ break;
+ } else {
+ new_stmt = SeqStmt::Flatten(std::move(stmts));
+ For new_for = For(chain_loop_vars.back().loop_var,
chain_loop_vars.back().min,
+ chain_loop_vars.back().extent,
chain_loop_vars.back().kind, new_stmt);
+
+ ObjectPtr<ForNode> current_loop = make_object<ForNode>(*new_for.get());
+ for (int i = chain_loop_vars.size() - 2; i >= 0; i--) {
+ new_for = For(chain_loop_vars[i].loop_var, chain_loop_vars[i].min,
+ chain_loop_vars[i].extent, chain_loop_vars[i].kind,
new_for);
+ }
+ new_stmt = SeqStmt::Flatten(std::move((SeqStmt({std::move(n->body),
std::move(new_for)}))));
+ n->body = std::move(new_stmt);
+ new_stmt = For(n);
+ break;
+ }
+ }
+ }
+ return new_stmt;
+}
+
+/*!
+ * \brief Inject the lowered warp evaluate allreduce block transformed from
the input reduction
+ * block
+ * \param realize The block-realize which contains the old reduction block
+ * \param ct_buffers The buffers to store cross-thread reduction results
+ * \param wb_buffers The buffers to store the final reduction results
+ * \param old_wb_indices The indices used to access the write-back buffers
when storing the final
+ * reduction results into the write-back buffers
+ * \param reducer The reduction function
+ * \param combiner_lhs The LHS values of the combiner
+ * \param reduction_loops The reduction loops
+ * \param used_block The block used in the warp evaluate (init block)
+ */
+Stmt InjectWarpEvaluateReductionBlock(const BlockRealizeNode* realize,
//
+ const Array<Buffer>& ct_buffers,
//
+ const Array<Buffer>& wb_buffers,
//
+ const Array<PrimExpr>& old_wb_indices,
//
+ const CommReducer& reducer,
//
+ const Array<PrimExpr>&
combiner_evaluate, //
+ const std::vector<const ForNode*>&
reduction_loops //
+ ) {
+ int n_buffers = wb_buffers.size();
+ const BlockNode* block = realize->block.get();
+ Buffer write_buffer = (*block->writes.begin())->buffer;
+ PrimExpr warp_size = write_buffer->shape[write_buffer->shape.size() - 2];
+ PrimExpr local_size = write_buffer->shape[write_buffer->shape.size() - 1];
+ // Create IterVars
+ Var v_lane_id = Var("v_lane_id");
+ Var v_local_id = Var("v_local_id");
+ // Create Bindings
+ Var ax_lane_id = Var("ax_lane_id");
+ Var ax_local_id = Var("ax_local_id");
+ // Create IterVar
+ IterVar lane_id = IterVar(Range(0, warp_size), v_lane_id, kDataPar,
"threadIdx.x");
+ IterVar local_id = IterVar(Range(0, local_size), v_local_id, kDataPar);
+
+ auto f_create_buffer_regions = [](Array<Buffer> buffers) {
+ Array<BufferRegion> regions;
+ regions.reserve(buffers.size());
+ for (const Buffer& buffer : buffers) {
+ regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
+ }
+ return regions;
+ };
+
+ Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
+ Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
+ // In total, the block is transformed into at most 4 statements
+ // - Stmt 1: initialize the buffer for in-thread reduction
+ // - Stmt 2: do in-thread reduction
+ // - Stmt 3: do cross-thread reduction
+ // - Stmt 4: write cross-thread reduction result to the original buffer
+ Array<Stmt> stmts;
+ stmts.reserve(4);
+
+ // Stmt 3: do cross-thread reduction
+ {
+ // Step 3.1. Create the parameters to the intrinsic
+ Array<PrimExpr> parameters;
+
+ // Step 3.2. Create the block and the block-realize.
+ Array<IterVar> iter_vars = block->iter_vars;
+ Array<PrimExpr> bindings = realize->iter_values;
+ Array<BufferRegion> reads = block->writes;
+
+ // Blockized block should also be considered
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ // If has ChildBlocks, the reads should be analyzed from the child blocks
+ reads.clear();
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ reads.insert(reads.end(), child_block->block->writes.begin(),
+ child_block->block->writes.end());
+ }
+ }
+
+ // Remove unused iter vars introduced by blockize;
+ // otherwise duplicated for loops may be generated
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ for (const BufferRegion& read : reads) {
+ for (const Range& range : read->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
+ }
+ }
+
+ // Create the vars for the warp
+ // Collect the Warp Information from the Buffer
+ // The threadIdx and Warp Information is stored in the last two dimensions
of the write buffer
+
+ iter_vars_used.push_back(lane_id);
+ iter_vars_used.push_back(local_id);
+
+ bindings_used.push_back(ax_lane_id);
+ bindings_used.push_back(ax_local_id);
+
+ Array<PrimExpr> combiner_warp;
+ for (size_t i = 0; i < combiner_evaluate.size(); i++) {
+ Array<PrimExpr> new_indices; // indices that maintain the new indices
with warp information
+ for (size_t j = 0; j <
combiner_evaluate[i].as<BufferLoadNode>()->indices.size() - 2; j++) {
+
new_indices.push_back(combiner_evaluate[i].as<BufferLoadNode>()->indices[j]);
+ }
+ new_indices.push_back(v_lane_id);
+ new_indices.push_back(v_local_id);
+ combiner_warp.push_back(
+ BufferLoad(combiner_evaluate[i].as<BufferLoadNode>()->buffer,
new_indices));
+ }
+ // Create the parameters to the intrinsic
+ parameters.reserve(reduction_loops.size() + 4);
+ // 1-st argument: number of buffers
+ parameters.push_back(make_const(DataType::UInt(32), n_buffers));
+ // Next `n_buffers` arguments: sources
+ parameters.insert(parameters.end(), combiner_warp.begin(),
combiner_warp.end());
+ // Next argument: predicate
+ parameters.push_back(const_true());
+ // Next `n_buffers` arguments: destinations
+ for (int i = 0; i < n_buffers; ++i) {
+ parameters.push_back(BufferLoad(ct_buffers[i], {0}));
+ }
+ // Next arguments: all the reduction threads
+ for (const ForNode* reduction_loop : reduction_loops) {
+ if (reduction_loop->thread_binding.defined()) {
+ parameters.push_back(reduction_loop->loop_var);
+ }
+ }
+ // update param indices
+ Block cross_thread_block =
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(reads),
+ /*writes=*/ct_buffer_regions,
+ /*name_hint=*/block->name_hint + "_cross_thread",
+ /*body=*/
+ AttrStmt(/*node=*/reducer,
+ /*attr_key=*/tir::attr::reduce_scope,
+ /*value=*/make_zero(DataType::Handle()),
+ /*body=*/
+ Evaluate(Call(/*dtype=*/DataType::Handle(),
+
/*op=*/tir::builtin::tvm_thread_allreduce(),
+ /*args=*/std::move(parameters)))));
+
+ ObjectPtr<BlockNode> cross_thread_block_node =
+ make_object<BlockNode>(*cross_thread_block.operator->());
+ cross_thread_block_node->annotations.Set(kIsCrossThreadReductionApplied,
Bool(true));
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/const_true(),
+ /*block=*/cross_thread_block));
+ }
+ // Stmt 4: write cross-thread reduction result to the original buffer
+ {
+ ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
+ int n_iter = static_cast<int>(block->iter_vars.size());
+ Array<IterVar> iter_vars;
+ Array<PrimExpr> bindings;
+ Array<Range> write_region = block->writes[0]->region;
+ iter_vars.reserve(n_iter);
+ bindings.reserve(n_iter);
+ for (int i = 0; i < n_iter; ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ const PrimExpr& binding = realize->iter_values[i];
+ if (iter_var->iter_type != kCommReduce) {
+ iter_vars.push_back(iter_var);
+ bindings.push_back(binding);
+ }
+ }
+
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ ICHECK_GT(child_blocks.size(), 0) << "Child blocks should be more than
0";
+
+ for (BlockRealize child_block : child_blocks) {
+ Array<IterVar> child_iter_vars = child_block->block->iter_vars;
+ Array<PrimExpr> child_bindings = child_block->iter_values;
+ iter_vars.insert(iter_vars.end(), child_iter_vars.begin(),
child_iter_vars.end());
+ bindings.insert(bindings.end(), child_bindings.begin(),
child_bindings.end());
+ write_region = child_block->block->writes[0]->region;
+ }
+ }
+ Array<Stmt> wb_updates;
+ Array<BufferRegion> wb_regions;
+ wb_updates.reserve(n_buffers);
+ wb_regions.reserve(n_buffers);
+
+ Array<Range> region = write_region;
+ Array<PrimExpr> wb_indices;
+ for (size_t i = 0; i < block->writes[0]->region.size() - 2; i++) {
+ Range range = block->writes[0]->region[i];
+ wb_indices.push_back(range->min);
+ }
+ // Append the warp and local id to the indices
+ wb_indices.push_back(v_lane_id);
+ wb_indices.push_back(v_local_id);
+ for (int i = 0; i < n_buffers; ++i) {
+ wb_updates.push_back(
+ BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}),
wb_indices));
+ wb_regions.push_back(BufferRegion(wb_buffers[i], region));
+ }
+
+ // Construct the predicate of the write-back block. It is the conjunction
of
+ // - each predicate clause of the original block which contains spatial
loop var, and
+ // - `t == 0` for each reduction thread dim when the write-back buffer is
not local.
+ PrimExpr wb_predicate = const_true();
+ std::unordered_set<const VarNode*> reduction_loop_vars;
+ reduction_loop_vars.reserve(reduction_loops.size());
+ for (const ForNode* reduction_loop : reduction_loops) {
+ reduction_loop_vars.insert(reduction_loop->loop_var.get());
+ }
+ PostOrderVisit(realize->predicate, [&wb_predicate,
&reduction_loop_vars](const ObjectRef& obj) {
+ if (const auto* and_node = obj.as<AndNode>()) {
+ Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
+ for (PrimExpr sub_expr : sub_exprs) {
+ if (sub_expr->IsInstance<AndNode>()) {
+ continue;
+ }
+ bool is_reduction = [sub_expr, &reduction_loop_vars]() {
+ Array<Var> vars = UndefinedVars(sub_expr);
+ for (Var var : vars) {
+ if (reduction_loop_vars.find(var.get()) !=
reduction_loop_vars.end()) {
+ return true;
+ }
+ }
+ return false;
+ }();
+ if (!is_reduction) {
+ wb_predicate = wb_predicate && sub_expr;
+ }
+ }
+ return true;
+ }
+ return false;
+ });
+ if (wb_buffers[0].scope() != "local") {
+ for (const ForNode* loop : reduction_loops) {
+ if (loop->thread_binding.defined()) {
+ wb_predicate = wb_predicate && (loop->loop_var ==
IntImm(loop->loop_var->dtype, 0));
+ }
+ }
+ }
+ // Remove unused iter vars introduced by blockize;
+ // otherwise duplicated for loops may be generated
+ Array<IterVar> iter_vars_used;
+ Array<PrimExpr> bindings_used;
+ auto f_inject = [&iter_vars, &bindings, &iter_vars_used,
+ &bindings_used](const VarNode* var) -> bool {
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ if (iter_var->var.get() == var) {
+ if (std::find(iter_vars_used.begin(), iter_vars_used.end(),
iter_var) ==
+ iter_vars_used.end()) {
+ iter_vars_used.push_back(iter_var);
+ bindings_used.push_back(bindings[i]);
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ for (const BufferRegion& write : wb_regions) {
+ for (const Range& range : write->region) {
+ UsedIterVarCollector::Collect(range->min, f_inject);
+ UsedIterVarCollector::Collect(range->extent, f_inject);
+ }
+ }
+
+ iter_vars_used.push_back(IterVar(Range(0, warp_size), v_lane_id, kDataPar,
"threadIdx.x"));
+ iter_vars_used.push_back(IterVar(Range(0, local_size), v_local_id,
kDataPar));
+
+ bindings_used.push_back(ax_lane_id);
+ bindings_used.push_back(ax_local_id);
+ stmts.push_back(BlockRealize(
+ /*iter_values=*/std::move(bindings_used),
+ /*predicate=*/wb_predicate,
+ /*block=*/
+ Block(/*iter_vars=*/std::move(iter_vars_used),
+ /*reads=*/std::move(ct_buffer_regions),
+ /*writes=*/std::move(wb_regions),
+ /*name_hint=*/block->name_hint + "_write_back",
+ /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
+ }
+ // Final step: Wrap the statements above with the reduction loops
bound to threadIdx
+ Stmt new_stmt = Stmt();
+ for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend();
++rit) {
+ const ForNode* loop = *rit;
+ if (loop->thread_binding.defined()) {
+ // Collect the loop vars between the reduction loops
+ std::vector<LoopVar> chain_loop_vars =
+ LoopVarCollector::Collect(loop->body, GetRef<Block>(block));
+ std::vector<LoopVar> used_chain_loop_vars_array;
+ if (HasChildBlocksChecker::Check(GetRef<Block>(block))) {
+ chain_loop_vars.clear();
+ Array<BlockRealize> child_blocks =
+ HasChildBlocksChecker::GetChildBlockRealizes(GetRef<Block>(block));
+ for (BlockRealize child_block : child_blocks) {
+ std::vector<LoopVar> child_loop_vars =
+ LoopVarCollector::Collect(loop->body, child_block->block);
+ chain_loop_vars.insert(chain_loop_vars.end(),
child_loop_vars.begin(),
+ child_loop_vars.end());
+ }
+ }
+
+ // Remove unused loops from the chain loops; otherwise duplicated
for loops may be generated
+ for (auto it = chain_loop_vars.begin(); it != chain_loop_vars.end();
++it) {
+ Var target_var = (*it).loop_var;
+ auto f_find = [&target_var](const VarNode* var) -> bool {
+ if (target_var.get() == var) {
+ return true;
+ }
+ return false;
+ };
+ for (const Stmt& stmt : stmts) {
+ if (UsesVar(stmt, f_find)) {
+ used_chain_loop_vars_array.push_back(*it);
+ break;
+ }
+ }
+ }
+ chain_loop_vars = used_chain_loop_vars_array;
+ // append warp related loops
+ chain_loop_vars.push_back(LoopVar(ax_lane_id,
IntImm(loop->loop_var->dtype, 0), warp_size,
+ ForKind::kThreadBinding));
+ chain_loop_vars.push_back(
+ LoopVar(ax_local_id, IntImm(loop->loop_var->dtype, 0), local_size,
ForKind::kSerial));
+ ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
+ if (chain_loop_vars.size() == 0) {
+ stmts.insert(stmts.begin(), n->body);
+ new_stmt = SeqStmt::Flatten(std::move(stmts));
+ n->body = std::move(new_stmt);
+ new_stmt = For(n);
+ break;
+ } else {
+ new_stmt = SeqStmt::Flatten(std::move(stmts));
+ For new_for = For(chain_loop_vars.back().loop_var,
chain_loop_vars.back().min,
+ chain_loop_vars.back().extent,
chain_loop_vars.back().kind, new_stmt);
+
+ ObjectPtr<ForNode> current_loop = make_object<ForNode>(*new_for.get());
+ for (int i = chain_loop_vars.size() - 2; i >= 0; i--) {
+ LoopVar loop_var = chain_loop_vars[i];
+ if (loop_var.kind == ForKind::kThreadBinding) {
+ new_for = For(loop_var.loop_var, loop_var.min, loop_var.extent,
loop_var.kind, new_for,
+ lane_id);
+ } else {
+ new_for = For(loop_var.loop_var, loop_var.min, loop_var.extent,
loop_var.kind, new_for);
+ }
+ }
+ new_stmt = SeqStmt::Flatten(std::move((SeqStmt({std::move(n->body),
std::move(new_for)}))));
+ n->body = std::move(new_stmt);
+ new_stmt = For(n);
+ break;
+ }
+ }
+ }
+ return new_stmt;
+}
+
+/*!
+ * \brief Detect cross-thread reduction pattern and then transform
+ */
+class CrossThreadReductionTransformer : public StmtMutator {
+ private:
+ // Check if the input block needs cross-thread reduction.
+ std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode*
realize) {
+ // Step 0. If the block is the root block, just return.
+ if (block_stack_.empty()) {
+ return {};
+ }
+
+ // Step 1. If the block is not a reduction block, cross-thread reduction
is not needed.
+ if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
+ GetRef<Block>(block_stack_.back()), &analyzer_)) {
+ return {};
+ }
+ // Step 2. Collect all the vars that appear in the bindings of reduction
block iters.
+ std::unordered_set<const VarNode*> reduction_vars;
+ GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr,
&reduction_vars);
+
+ // Step 3. Collect the loops whose loop vars appear in the bindings of
reduction block iters.
+ // We call these loops "reduction-related".
+ // Step 4. See whether at least one reduction-related loop is bound to
thread axis in GPU - if
+ // so, cross-thread reduction is needed. If none of the reduction-related
loops is bound to
+ // thread axis, cross-thread reduction is not needed for the input block.
+ bool need = false;
+ std::vector<const ForNode*> reduction_loops;
+ for (const ForNode* loop : loop_stack_) {
+ if (reduction_vars.count(loop->loop_var.get())) {
+ // Step 3. Collect the loop.
+ reduction_loops.push_back(loop);
+ // Step 4. See whether the loop is bound to some thread axis.
+ if (loop->thread_binding.defined()) {
+ need = true;
+ }
+ }
+ }
+ return need ? reduction_loops : std::vector<const ForNode*>{};
+ }
+
+ // Check if the input block needs thread broadcast rewrite.
+ // One block needs broadcast rewrite when
+ // 1. it consumes a buffer produced by cross-thread reduction under
+ // the same kernel (i.e., same group of blockIdx),
+ // 2. it writes to non-local memory,
+ // 3. at least one of the reduction thread vars of the cross-thread reduction
+ // is free to this block (i.e., not bound to the block).
+ std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
+ const BlockRealizeNode* realize) {
+ Block block = realize->block;
+
+ // If the block writes to local memory, no rewrite is needed.
+ for (BufferRegion write_region : block->writes) {
+ if (write_region->buffer.scope() == "local") {
+ return {};
+ }
+ }
+
+ // Find out the reduction threads for the read-buffers which are produced
by
+ // cross-thread reduction.
+ std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual>
thread2range;
+ for (BufferRegion read_region : block->reads) {
+ auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
+ if (buf_it == crt_buf2threads_.end()) {
+ continue;
+ }
+ for (auto[scope, range] : buf_it->second) {
Review Comment:
requires `clang-format`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]